class TableConfigurer(HasTraits):
    """Editable configuration for a tabular view: visible columns and font size.

    The selected columns, the user-defined "sparse" column subset and the font
    size are persisted with ``pickle`` to ``paths.hidden_dir/<id>`` and pushed to
    ``adapter`` (and every adapter in ``children``) by ``set_columns`` and
    ``set_font``.
    """

    # Column names currently selected for display (first element of the
    # adapter's column tuples, see ``set_adapter``).
    columns = List
    # Additional adapters kept in sync with ``adapter``.
    children = List(TabularAdapter)
    # Every column name the adapter offers (filled by ``set_adapter``).
    available_columns = List
    # Reduced column set shown by "Toggle Sparse".
    sparse_columns = List
    adapter = Instance(TabularAdapter)
    # File name of the persisted state inside ``paths.hidden_dir``.
    id = 'table'
    font = Enum(*SIZES)
    # When True, font/column changes call ``update()`` immediately.
    auto_set = Bool(False)
    fontsize_enabled = Bool(True)
    title = Str('Configure Table')
    # Optional callback invoked after the font has been applied.
    refresh_func = Callable
    show_all = Button('Show All')
    set_sparse = Button('Define Sparse')
    toggle_sparse = Button('Toggle Sparse')
    # The getter compares ``columns`` with ``sparse_columns``, so it must be
    # re-evaluated when either list changes.
    sparse_enabled = Property(depends_on='columns[],sparse_columns[]')
    _toggle_sparse_enabled = Bool(False)
    default_button = Button('Default')
    # YAML file providing the laboratory default ``columns``.
    defaults_path = Str

    def _get_sparse_enabled(self):
        # A sparse set may be defined only when it would differ from the
        # current one and fewer than 5 columns are selected.
        return self.sparse_columns != self.columns and len(self.columns) < 5

    def __init__(self, *args, **kw):
        super(TableConfigurer, self).__init__(*args, **kw)
        if self.auto_set:
            self.on_trait_change(self.update, 'font, columns[]')
        self._load_state()

    def closed(self, is_ok):
        """Persist and apply the configuration when the dialog is accepted."""
        if is_ok:
            self.dump()
            self.set_columns()
            self.set_font()

    def load(self):
        """Reload the persisted state."""
        self._load_state()

    def dump(self):
        """Persist the current state."""
        self._dump_state()

    def update(self):
        """Apply the current font and column selection to the adapters."""
        self.set_font()
        self.set_columns()

    def set_font(self):
        """Push ``font`` as an ``'arial <size>'`` string to all adapters."""
        if self.adapter:
            font = 'arial {}'.format(self.font)
            self.adapter.font = font
            for ci in self.children:
                ci.font = font
            if self.refresh_func:
                self.refresh_func()

    def set_columns(self):
        """Push the selected ``columns`` to ``children`` and ``adapter``."""
        if self.adapter:
            cols = self._assemble_columns()
            for ci in self.children:
                ci.columns = cols
            # The main adapter only receives entries present in its all_columns.
            self.adapter.columns = [c for c in cols if c in self.adapter.all_columns]

    def _set_font(self, f):
        # ``f`` must expose ``pointSize()`` (presumably a Qt QFont -- confirm).
        self.font = f.pointSize()

    def _load_state(self):
        """Restore columns, sparse columns and font from the pickled state file.

        Missing, unreadable or malformed state is ignored. The file is written
        by ``_dump_state``. ``pickle.load`` can execute arbitrary code, so the
        hidden directory must never contain untrusted files.
        """
        p = os.path.join(paths.hidden_dir, self.id)
        if os.path.isfile(p):
            try:
                with open(p, 'rb') as rfile:
                    state = pickle.load(rfile)
            except (pickle.PickleError, OSError, EOFError, TraitError):
                return

            if not isinstance(state, dict):
                # Unexpected payload: every lookup below needs a dict.
                return

            sparse = state.get('sparse_columns')
            if sparse is not None:
                try:
                    self.sparse_columns = sparse
                except TraitError:
                    # Stored value is not a valid list; keep the current set.
                    pass

            cols = state.get('columns')
            if cols:
                # Keep the order of available_columns and drop stale names.
                self.columns = [ai for ai in self.available_columns if ai in cols]

            font = state.get('font', None)
            if font:
                try:
                    self.font = font
                except TraitError:
                    # Saved size is no longer one of SIZES; keep current font.
                    pass

            self._load_hook(state)
            self.update()

    def _dump_state(self):
        """Pickle ``_get_dump()`` to the state file.

        The object is serialized before the file is opened, so a pickling
        failure leaves previously saved state intact instead of truncating it.
        """
        p = os.path.join(paths.hidden_dir, self.id)
        obj = self._get_dump()
        try:
            data = pickle.dumps(obj)
        except pickle.PickleError:
            return
        with open(p, 'wb') as wfile:
            wfile.write(data)

    def _get_dump(self):
        """Return the state dict to persist; subclasses may extend it."""
        return dict(columns=self.columns, font=self.font,
                    sparse_columns=self.sparse_columns)

    def _load_hook(self, state):
        """Hook for subclasses to restore extra keys from ``state``."""
        pass

    def _assemble_columns(self):
        """Return ``(name, all_columns_dict[name])`` for selected columns, in adapter order."""
        d = self.adapter.all_columns_dict
        return [(k, d[k]) for k, _ in self.adapter.all_columns if k in self.columns]

    def _get_columns_grp(self):
        return

    def _set_defaults(self):
        """Load the laboratory default ``columns`` from ``defaults_path`` (YAML)."""
        p = self.defaults_path
        if os.path.isfile(p):
            import yaml
            with open(p, 'r') as rfile:
                # safe_load: the defaults file is plain data. Bare yaml.load()
                # without a Loader raises TypeError on PyYAML >= 6.
                yd = yaml.safe_load(rfile)
            # An empty file loads as None; only a mapping can carry 'columns'.
            if isinstance(yd, dict) and 'columns' in yd:
                self.columns = yd['columns']
            self.set_columns()

    def _default_button_fired(self):
        self._set_defaults()

    def _set_sparse_fired(self):
        self.sparse_columns = self.columns

    def _toggle_sparse_fired(self):
        # Alternate between the sparse set and the columns shown before it.
        if self._toggle_sparse_enabled:
            columns = self._prev_columns
        else:
            self._prev_columns = self.columns
            columns = self.sparse_columns
        self.columns = columns
        self.set_columns()
        self._toggle_sparse_enabled = not self._toggle_sparse_enabled

    def _show_all_fired(self):
        self.columns = self.available_columns
        self.set_columns()

    def set_adapter(self, adp):
        """Attach ``adp``, initialise columns/font from it, then restore saved state."""
        self.adapter = adp
        acols = [c for c, _ in adp.all_columns]
        # set currently visible columns
        t = [c for c, _ in adp.columns]
        cols = [c for c in acols if c in t]
        self.trait_set(columns=cols)
        # set all available columns
        self.available_columns = acols
        if adp.font:
            self._set_font(adp.font)
        self._load_state()

    def traits_view(self):
        v = okcancel_view(
            VGroup(
                HGroup(
                    UItem('show_all', tooltip='Show all columns'),
                    UItem('set_sparse',
                          tooltip='Set the current set of columns to the Sparse Column Set',
                          enabled_when='sparse_enabled'),
                    UItem('toggle_sparse', tooltip='Display only Sparse Column Set'),
                    UItem('default_button',
                          tooltip='Set to Laboratory defaults. File located at '
                                  '[root]/experiments/experiment_defaults.yaml')),
                VGroup(
                    UItem('columns', style='custom',
                          editor=CheckListEditor(name='available_columns', cols=3)),
                    Item('font', enabled_when='fontsize_enabled'),
                    show_border=True)),
            handler=TableConfigurerHandler(),
            title=self.title)
        return v
class ArraySource(Source):
    """A simple source that allows one to view a suitably shaped numpy array
    as ImageData. This supports both scalar and vector data.

    The arrays are copied into ``image_data``, which feeds
    ``change_information_filter``. The output of that filter is the single
    entry in ``outputs``, and the filter is also where ``spacing`` and
    ``origin`` are stored.
    """

    # The scalar array data we manage (validated by _check_scalar_array).
    scalar_data = Trait(None, _check_scalar_array, rich_compare=False)

    # The name of our scalar array.
    scalar_name = Str('scalar')

    # The vector array data we manage (validated by _check_vector_array).
    vector_data = Trait(None, _check_vector_array, rich_compare=False)

    # The name of our vector array.
    vector_name = Str('vector')

    # The spacing of the points in the array (stored on the filter).
    spacing = DelegatesTo('change_information_filter', 'output_spacing',
                          desc='the spacing between points in array')

    # The origin of the points in the array (stored on the filter).
    origin = DelegatesTo('change_information_filter', 'output_origin',
                         desc='the origin of the points in array')

    # Fire an event to update the spacing and origin. This
    # is here for backwards compatibility. Firing this is no
    # longer needed.
    update_image_data = Button('Update spacing and origin')

    # The image data stored by this instance.
    image_data = Instance(tvtk.ImageData, (), allow_none=False)

    # Use an ImageChangeInformation filter to reliably set the
    # spacing and origin on the output
    change_information_filter = Instance(tvtk.ImageChangeInformation,
                                         args=(),
                                         kw={
                                             'output_spacing': (1.0, 1.0, 1.0),
                                             'output_origin': (0.0, 0.0, 0.0)
                                         })

    # Should we transpose the input data or not. Transposing is
    # necessary to make the numpy array compatible with the way VTK
    # needs it. However, transposing numpy arrays makes them
    # non-contiguous where the data is copied by VTK. Thus, when the
    # user explicitly requests that transpose_input_array is false
    # then we assume that the array has already been suitably
    # formatted by the user.
    transpose_input_array = Bool(
        True,
        desc=
        'if input array should be transposed (if on VTK will copy the input data)'
    )

    # Information about what this object can produce.
    output_info = PipelineInfo(datasets=['image_data'])

    # Specify the order of dimensions. The default is: [0, 1, 2]
    # (only _scalar_data_changed uses it, to permute the extents).
    dimensions_order = List(Int, [0, 1, 2])

    # Our view.
    view = View(
        Group(Item(name='transpose_input_array'),
              Item(name='scalar_name'),
              Item(name='vector_name'),
              Item(name='spacing'),
              Item(name='origin'),
              show_labels=True))

    ######################################################################
    # `object` interface.
    ######################################################################
    def __init__(self, **traits):
        # Set the scalar and vector data at the end so we pop it here.
        # (Their change handlers need the filter wired to image_data first.)
        sd = traits.pop('scalar_data', None)
        vd = traits.pop('vector_data', None)
        # Now set the other traits.
        super(ArraySource, self).__init__(**traits)
        self.configure_input_data(self.change_information_filter,
                                  self.image_data)

        # And finally set the scalar and vector data.
        if sd is not None:
            self.scalar_data = sd
        if vd is not None:
            self.vector_data = vd

        self.outputs = [self.change_information_filter.output]
        self.on_trait_change(self._information_changed, 'spacing,origin')

    def __get_pure_state__(self):
        # image_data is excluded from the persisted state.
        d = super(ArraySource, self).__get_pure_state__()
        d.pop('image_data', None)
        return d

    ######################################################################
    # ArraySource interface.
    ######################################################################
    def update(self):
        """Call this function when you change the array data in-place.

        Marks the image data and its current point scalars/vectors as
        modified, re-runs the change-information filter and sets
        ``data_changed``.
        """
        d = self.image_data
        d.modified()
        pd = d.point_data
        if self.scalar_data is not None:
            pd.scalars.modified()
        if self.vector_data is not None:
            pd.vectors.modified()
        self.change_information_filter.update()
        self.data_changed = True

    ######################################################################
    # Non-public interface.
    ######################################################################
    def _image_data_changed(self, value):
        # Re-wire the filter input when image_data is replaced.
        self.configure_input_data(self.change_information_filter, value)

    def _scalar_data_changed(self, data):
        """Rebuild image_data's geometry and point scalars from ``data``.

        ``None`` clears the point scalars. A 2-D array gets a trailing
        dimension of length 1.
        """
        img_data = self.image_data
        if data is None:
            img_data.point_data.scalars = None
            self.data_changed = True
            return
        dims = list(data.shape)
        if len(dims) == 2:
            dims.append(1)

        # set the dimension indices
        dim0, dim1, dim2 = self.dimensions_order

        img_data.origin = tuple(self.origin)
        # NOTE(review): `dimensions` uses dims unpermuted while the extents
        # below are permuted by dimensions_order; with a non-default order the
        # two may disagree -- confirm this is intended.
        img_data.dimensions = tuple(dims)
        img_data.extent = 0, dims[dim0] - 1, 0, dims[dim1] - 1, 0, dims[
            dim2] - 1
        if is_old_pipeline():
            img_data.update_extent = 0, dims[dim0] - 1, 0, dims[
                dim1] - 1, 0, dims[dim2] - 1
        else:
            # Newer pipeline: the update extent is set on the filter instead.
            update_extent = [
                0, dims[dim0] - 1, 0, dims[dim1] - 1, 0, dims[dim2] - 1
            ]
            self.change_information_filter.set_update_extent(update_extent)
        if self.transpose_input_array:
            img_data.point_data.scalars = numpy.ravel(numpy.transpose(data))
        else:
            img_data.point_data.scalars = numpy.ravel(data)
        img_data.point_data.scalars.name = self.scalar_name
        # This is very important and if not done can lead to a segfault!
        # (the VTK scalar type must match the numpy dtype)
        typecode = data.dtype
        if is_old_pipeline():
            img_data.scalar_type = array_handler.get_vtk_array_type(typecode)
            img_data.update()  # This sets up the extents correctly.
        else:
            filter_out_info = self.change_information_filter.get_output_information(
                0)
            img_data.set_point_data_active_scalar_info(
                filter_out_info, array_handler.get_vtk_array_type(typecode),
                -1)
            img_data.modified()
        img_data.update_traits()
        self.change_information_filter.update()

        # Now flush the mayavi pipeline.
        self.data_changed = True

    def _vector_data_changed(self, data):
        """Rebuild image_data's geometry and point vectors from ``data``.

        ``None`` clears the point vectors. The last axis holds the 3 vector
        components (the data is reshaped to ``(size // 3, 3)``). A 3-D input
        gets a length-1 axis inserted at position 2 before the component axis.
        ``dimensions_order`` is not used here.
        """
        img_data = self.image_data
        if data is None:
            img_data.point_data.vectors = None
            self.data_changed = True
            return
        dims = list(data.shape)
        if len(dims) == 3:
            dims.insert(2, 1)
            data = numpy.reshape(data, dims)
        img_data.origin = tuple(self.origin)
        # Drop the trailing component axis from the grid dimensions.
        img_data.dimensions = tuple(dims[:-1])
        img_data.extent = 0, dims[0] - 1, 0, dims[1] - 1, 0, dims[2] - 1
        if is_old_pipeline():
            img_data.update_extent = 0, dims[0] - 1, 0, dims[1] - 1, 0, dims[
                2] - 1
        else:
            self.change_information_filter.update_information()
            update_extent = [0, dims[0] - 1, 0, dims[1] - 1, 0, dims[2] - 1]
            self.change_information_filter.set_update_extent(update_extent)
        sz = numpy.size(data)
        if self.transpose_input_array:
            # Reverse the three spatial axes; keep components last.
            data_t = numpy.transpose(data, (2, 1, 0, 3))
        else:
            data_t = data
        img_data.point_data.vectors = numpy.reshape(data_t, (sz // 3, 3))
        img_data.point_data.vectors.name = self.vector_name
        if is_old_pipeline():
            img_data.update()  # This sets up the extents correctly.
        else:
            img_data.modified()
        img_data.update_traits()
        self.change_information_filter.update()

        # Now flush the mayavi pipeline.
        self.data_changed = True

    def _scalar_name_changed(self, value):
        # Rename the existing point-scalar array, if any.
        if self.scalar_data is not None:
            self.image_data.point_data.scalars.name = value
            self.data_changed = True

    def _vector_name_changed(self, value):
        # Rename the existing point-vector array, if any.
        if self.vector_data is not None:
            self.image_data.point_data.vectors.name = value
            self.data_changed = True

    def _transpose_input_array_changed(self, value):
        # Re-import the current arrays with the new transpose setting.
        if self.scalar_data is not None:
            self._scalar_data_changed(self.scalar_data)
        if self.vector_data is not None:
            self._vector_data_changed(self.vector_data)

    def _information_changed(self):
        # spacing/origin live on the filter; re-run it and notify.
        self.change_information_filter.update()
        self.data_changed = True
class ODMR(ManagedJob, GetSetItemsMixin):
    """Provides ODMR measurements.

    Sweeps the microwave frequency over ``frequency`` and accumulates the
    counter readings per sweep into ``counts`` (summed) and ``counts_matrix``
    (the most recent ``n_lines`` sweeps, newest first). Optionally fits
    Lorentzians to ``counts``.
    """

    # starting and stopping
    # helper variable to decide whether to keep existing data
    keep_data = Bool(False)
    resubmit_button = Button(
        label='resubmit',
        desc='Submits the measurement to the job manager. Tries to keep previously acquired data. Behaves like a normal submit if sequence or time bins have changed since previous run.'
    )

    # measurement parameters
    power = Range(low=-100., high=25., value=-12, desc='Power [dBm]',
                  label='Power [dBm]', mode='text', auto_set=False,
                  enter_set=True)
    frequency_begin = Range(low=1, high=20e9, value=2.82e9,
                            desc='Start Frequency [Hz]', label='Begin [Hz]',
                            editor=TextEditor(auto_set=False, enter_set=True,
                                              evaluate=float, format_str='%e'))
    frequency_end = Range(low=1, high=20e9, value=2.88e9,
                          desc='Stop Frequency [Hz]', label='End [Hz]',
                          editor=TextEditor(auto_set=False, enter_set=True,
                                            evaluate=float, format_str='%e'))
    frequency_delta = Range(low=1e-3, high=20e9, value=1e6,
                            desc='frequency step [Hz]', label='Delta [Hz]',
                            editor=TextEditor(auto_set=False, enter_set=True,
                                              evaluate=float, format_str='%e'))
    t_pi = Range(low=1., high=100000., value=1000.,
                 desc='length of pi pulse [ns]', label='pi [ns]', mode='text',
                 auto_set=False, enter_set=True)
    laser = Range(low=1., high=10000., value=300., desc='laser [ns]',
                  label='laser [ns]', mode='text', auto_set=False,
                  enter_set=True)
    wait = Range(low=1., high=10000., value=1000., desc='wait [ns]',
                 label='wait [ns]', mode='text', auto_set=False,
                 enter_set=True)
    pulsed = Bool(False, label='pulsed', enabled_when='state != "run"')
    # Pulse sequence used when ``pulsed`` is set (see _get_sequence).
    sequence = Property(trait=List, depends_on='laser,wait,t_pi')
    seconds_per_point = Range(low=20e-3, high=1, value=20e-3,
                              desc='Seconds per point',
                              label='Seconds per point', mode='text',
                              auto_set=False, enter_set=True)
    stop_time = Range(low=1., value=np.inf,
                      desc='Time after which the experiment stops by itself [s]',
                      label='Stop time [s]', mode='text', auto_set=False,
                      enter_set=True)
    n_lines = Range(low=1, high=10000, value=50,
                    desc='Number of lines in Matrix', label='Matrix lines',
                    mode='text', auto_set=False, enter_set=True)

    # control data fitting
    perform_fit = Bool(False, label='perform fit')
    number_of_resonances = Trait(
        'auto', String('auto', auto_set=False, enter_set=True),
        Int(10000., desc='Number of Lorentzians used in fit', label='N',
            auto_set=False, enter_set=True))
    threshold = Range(
        low=-99, high=99., value=-50.,
        desc='Threshold for detection of resonances [%]. The sign of the threshold specifies whether the resonances are negative or positive.',
        label='threshold [%]', mode='text', auto_set=False, enter_set=True)

    # fit result
    fit_parameters = Array(value=np.array((np.nan, np.nan, np.nan, np.nan)))
    fit_frequencies = Array(value=np.array((np.nan, )),
                            label='frequencies [Hz]')
    fit_line_width = Float(np.nan, label='line width [Hz]',
                           editor=TextEditor(evaluate=float,
                                             format_str='%.3e'))
    fit_contrast = Float(np.nan, label='contrast [%]',
                         editor=TextEditor(evaluate=float, format_str='%.1f'))

    # measurement data
    frequency = Array()
    counts = Array()
    counts_matrix = Array()
    run_time = Float(value=0.0, desc='Run time [s]', label='Run time [s]')

    # plotting
    line_label = Instance(PlotLabel)
    line_data = Instance(ArrayPlotData)
    matrix_data = Instance(ArrayPlotData)
    line_plot = Instance(Plot, editor=ComponentEditor())
    matrix_plot = Instance(Plot, editor=ComponentEditor())

    def __init__(self):
        super(ODMR, self).__init__()
        self._create_line_plot()
        self._create_matrix_plot()
        # Plot updates are dispatched to the UI thread because the data
        # traits are written from the acquisition thread (_run).
        self.on_trait_change(self._update_line_data_index, 'frequency',
                             dispatch='ui')
        self.on_trait_change(self._update_line_data_value, 'counts',
                             dispatch='ui')
        self.on_trait_change(self._update_line_data_fit, 'fit_parameters',
                             dispatch='ui')
        self.on_trait_change(self._update_matrix_data_value, 'counts_matrix',
                             dispatch='ui')
        self.on_trait_change(self._update_matrix_data_index,
                             'n_lines,frequency', dispatch='ui')

    def _counts_matrix_default(self):
        return np.zeros((self.n_lines, len(self.frequency)))

    def _frequency_default(self):
        # frequency_end is included (arange stops before its upper bound).
        return np.arange(self.frequency_begin,
                         self.frequency_end + self.frequency_delta,
                         self.frequency_delta)

    def _counts_default(self):
        return np.zeros(self.frequency.shape)

    def _get_sequence(self):
        # Repeated laser / wait / pi-pulse block played in pulsed mode.
        return 100 * [(['laser', 'aom'], self.laser), ([], self.wait),
                      (['mw'], self.t_pi)]

    # data acquisition

    def apply_parameters(self):
        """Apply the current parameters and decide whether to keep previous data.

        Accumulated data is cleared on a fresh submission (``keep_data`` is
        False) or when the frequency grid changed. ``np.array_equal`` treats
        grids of different length as unequal. An elementwise ``!=`` on
        mismatched shapes raises on NumPy >= 1.25.
        """
        frequency = self._frequency_default()
        clear = (not self.keep_data
                 or not np.array_equal(frequency, self.frequency))
        if clear:
            self.counts = np.zeros(frequency.shape)
            self.run_time = 0.0
        self.frequency = frequency
        if clear:
            # Reset the matrix here, on the acquisition thread, so its width
            # matches the new grid before _run stacks new lines onto it. The
            # 'frequency' listener only resets it later, on the UI thread.
            self.counts_matrix = self._counts_matrix_default()
        # when job manager stops and starts the job, data should be kept.
        # Only new submission should clear data.
        self.keep_data = True

    def _run(self):
        """Acquisition loop executed by the job manager's worker thread."""
        try:
            self.state = 'run'
            self.apply_parameters()

            if self.run_time >= self.stop_time:
                self.state = 'done'
                return

            # if pulsed, turn on sequence
            if self.pulsed:
                ha.PulseGenerator().Sequence(self.sequence)
            else:
                ha.PulseGenerator().Open()

            n = len(self.frequency)

            ha.MicrowaveB().setPower(self.power)
            ha.MicrowaveB().initSweep(
                self.frequency, self.power * np.ones(self.frequency.shape))
            ha.CounterB().configure(n, self.seconds_per_point, DutyCycle=0.8)
            time.sleep(0.5)

            while self.run_time < self.stop_time:
                start_time = time.time()
                if threading.current_thread().stop_request.isSet():
                    break
                ha.MicrowaveB().resetListPos()
                counts = ha.CounterB().run()
                self.run_time += time.time() - start_time
                self.counts += counts
                # Newest sweep goes on top; the oldest line drops off.
                self.counts_matrix = np.vstack(
                    (counts, self.counts_matrix[:-1, :]))
                # counts was modified in place, so notify explicitly.
                self.trait_property_changed('counts', self.counts)

            if self.run_time < self.stop_time:
                self.state = 'idle'
            else:
                self.state = 'done'
            ha.MicrowaveB().setOutput(None, self.frequency_begin)
            ha.PulseGenerator().Light()
            ha.CounterB().clear()
        except Exception:
            # Thread boundary: log and flag the job instead of dying silently.
            logging.getLogger().exception('Error in odmr.')
            self.state = 'error'

    # fitting
    @on_trait_change('counts,perform_fit,number_of_resonances,threshold')
    def update_fit(self):
        """Refit the Lorentzian model to ``counts`` and publish the fit results."""
        if self.perform_fit:
            N = self.number_of_resonances
            if N != 'auto':
                N = int(N)
            try:
                p = fitting.fit_multiple_lorentzians(
                    self.frequency, self.counts, N,
                    threshold=self.threshold * 0.01)
            except Exception:
                logging.getLogger().debug('ODMR fit failed.', exc_info=True)
                p = np.full(4, np.nan)
        else:
            p = np.full(4, np.nan)
        # p = [offset, f1, w1, a1, f2, w2, a2, ...]
        self.fit_parameters = p
        self.fit_frequencies = p[1::3]
        self.fit_line_width = p[2::3].mean()
        self.fit_contrast = -100 * p[3::3].mean() / (np.pi * p[2::3].mean() *
                                                     p[0])

    # plotting
    def _create_line_plot(self):
        line_data = ArrayPlotData(frequency=np.array((0., 1.)),
                                  counts=np.array((0., 0.)),
                                  fit=np.array((0., 0.)))
        line_plot = Plot(line_data, padding=8, padding_left=64,
                         padding_bottom=32)
        line_plot.plot(('frequency', 'counts'), style='line', color='blue')
        line_plot.index_axis.title = 'Frequency [MHz]'
        line_plot.value_axis.title = 'Fluorescence counts'
        line_label = PlotLabel(text='', hjustify='left', vjustify='bottom',
                               position=[64, 32])
        line_plot.overlays.append(line_label)
        self.line_label = line_label
        self.line_data = line_data
        self.line_plot = line_plot

    def _create_matrix_plot(self):
        matrix_data = ArrayPlotData(image=np.zeros((2, 2)))
        matrix_plot = Plot(matrix_data, padding=8, padding_left=64,
                           padding_bottom=32)
        matrix_plot.index_axis.title = 'Frequency [MHz]'
        matrix_plot.value_axis.title = 'line #'
        matrix_plot.img_plot('image',
                             xbounds=(self.frequency[0], self.frequency[-1]),
                             ybounds=(0, self.n_lines),
                             colormap=Spectral)
        self.matrix_data = matrix_data
        self.matrix_plot = matrix_plot

    def _perform_fit_changed(self, new):
        plot = self.line_plot
        if new:
            plot.plot(('frequency', 'fit'), style='line', color='red',
                      name='fit')
            self.line_label.visible = True
        else:
            plot.delplot('fit')
            self.line_label.visible = False
        plot.request_redraw()

    def _update_line_data_index(self):
        # Plot axis is in MHz; a new grid also invalidates the matrix.
        self.line_data.set_data('frequency', self.frequency * 1e-6)
        self.counts_matrix = self._counts_matrix_default()

    def _update_line_data_value(self):
        self.line_data.set_data('counts', self.counts)

    def _update_line_data_fit(self):
        if not np.isnan(self.fit_parameters[0]):
            self.line_data.set_data(
                'fit',
                fitting.NLorentzians(*self.fit_parameters)(self.frequency))
            p = self.fit_parameters
            f = p[1::3]
            w = p[2::3]
            c = 100 * p[3::3] / (np.pi * w * p[0])
            self.line_label.text = ''.join(
                'f %i: %.6e Hz, HWHM %.3e Hz, contrast %.1f%%\n' %
                (i + 1, fi, wi, ci)
                for i, (fi, wi, ci) in enumerate(zip(f, w, c)))

    def _update_matrix_data_value(self):
        self.matrix_data.set_data('image', self.counts_matrix)

    def _update_matrix_data_index(self):
        # Grow (zero-padded) or shrink the matrix to n_lines rows.
        if self.n_lines > self.counts_matrix.shape[0]:
            self.counts_matrix = np.vstack(
                (self.counts_matrix,
                 np.zeros((self.n_lines - self.counts_matrix.shape[0],
                           self.counts_matrix.shape[1]))))
        else:
            self.counts_matrix = self.counts_matrix[:self.n_lines]
        self.matrix_plot.components[0].index.set_data(
            (self.frequency[0] * 1e-6, self.frequency[-1] * 1e-6),
            (0.0, float(self.n_lines)))

    # saving data
    def save_line_plot(self, filename):
        self.save_figure(self.line_plot, filename)

    def save_matrix_plot(self, filename):
        self.save_figure(self.matrix_plot, filename)

    def save_all(self, filename):
        self.save_line_plot(filename + '_ODMR_Line_Plot.png')
        self.save_matrix_plot(filename + '_ODMR_Matrix_Plot.png')
        self.save(filename + '_ODMR.pys')

    # react to GUI events
    def submit(self):
        """Submit the job to the JobManager, discarding previous data."""
        self.keep_data = False
        ManagedJob.submit(self)

    def resubmit(self):
        """Submit the job to the JobManager, keeping previous data if possible."""
        self.keep_data = True
        ManagedJob.submit(self)

    def _resubmit_button_fired(self):
        """React to resubmit button. Resubmit the Job."""
        self.resubmit()

    traits_view = View(
        VGroup(
            HGroup(
                Item('submit_button', show_label=False),
                Item('remove_button', show_label=False),
                Item('resubmit_button', show_label=False),
                Item('priority'),
                Item('state', style='readonly'),
                Item('run_time', style='readonly', format_str='%.f'),
            ),
            Group(
                VGroup(
                    HGroup(
                        Item('power', width=-40),
                        Item('frequency_begin', width=-80),
                        Item('frequency_end', width=-80),
                        Item('frequency_delta', width=-80),
                        Item('pulsed'),
                        Item('t_pi', width=-80),
                    ),
                    HGroup(
                        Item('perform_fit'),
                        Item('number_of_resonances'),
                        Item('threshold'),
                        Item('n_lines'),
                        Item('stop_time'),
                    ),
                    HGroup(
                        Item('fit_contrast', style='readonly'),
                        Item('fit_line_width', style='readonly'),
                        Item('fit_frequencies', style='readonly'),
                    ),
                    label='parameters'),
                VGroup(
                    HGroup(Item('seconds_per_point'), ),
                    HGroup(
                        Item('laser', width=40),
                        Item('wait', width=40),
                    ),
                    label='settings'),
                layout='tabbed',
            ),
            VSplit(
                Item('matrix_plot', show_label=False, resizable=True),
                Item('line_plot', show_label=False, resizable=True),
            ),
        ),
        menubar=MenuBar(
            Menu(Action(action='saveLinePlot', name='SaveLinePlot (.png)'),
                 Action(action='saveMatrixPlot',
                        name='SaveMatrixPlot (.png)'),
                 Action(action='save', name='Save (.pyd or .pys)'),
                 Action(action='saveAll', name='Save All (.png+.pys)'),
                 Action(action='export', name='Export as Ascii (.asc)'),
                 Action(action='load', name='Load'),
                 Action(action='_on_close', name='Quit'),
                 name='File')),
        title='ODMR',
        width=900,
        height=800,
        buttons=[],
        resizable=True,
        handler=ODMRHandler)

    get_set_items = [
        'frequency', 'counts', 'counts_matrix', 'fit_parameters',
        'fit_contrast', 'fit_line_width', 'fit_frequencies', 'perform_fit',
        'run_time', 'power', 'frequency_begin', 'frequency_end',
        'frequency_delta', 'laser', 'wait', 'pulsed', 't_pi',
        'seconds_per_point', 'stop_time', 'n_lines', 'number_of_resonances',
        'threshold', '__doc__'
    ]
class DataFrameAnalyzerView(ModelView): """ Flexible ModelView class for a DataFrameAnalyzer. The view is built using many methods to build each component of the view so it can easily be subclassed and customized. TODO: add traits events to pass update/refresh notifications to the DFEditors once we have updated TraitsUI. TODO: Add traits events to receive notifications that a column/row was clicked/double-clicked. """ #: Model being viewed model = Instance(DataFrameAnalyzer) #: Selected list of data columns to display and analyze visible_columns = List(Str) #: Check box to hide/show what stats are included in the summary DF show_summary_controls = Bool #: Show the summary categorical df show_categorical_summary = Bool(True) #: Check box to hide/show what columns to analyze (panel when few columns) show_column_controls = Bool #: Open control for what columns to analyze (popup when many columns) open_column_controls = Button("Show column control") #: Button to launch the plotter tool when plotter_layout=popup plotter_launcher = Button("Launch Plot Tool") # Plotting tool attributes ------------------------------------------------ #: Does the UI expose a DF plotter? include_plotter = Bool #: Plot manager view to display. Ignored if include_plotter is False. 
plotter = Instance(DataFramePlotManagerView) # Styling and branding attributes ----------------------------------------- #: String describing the font to use, or dict mapping column names to font fonts = Either(Str, Dict) #: Name of the font to use if same across all columns font_name = Str(DEFAULT_FONT) #: Size of the font to use if same across all columns font_size = Int(14) #: Number of digits to display in the tables display_precision = Int(-1) #: Formatting to use to include formats = Either(Str, Dict) #: UI title for the Data section data_section_title = Str("Data") #: Exploration group label: visible only when plotter_layout="Tabbed" exploration_group_label = Str("Exploration Tools") #: Plotting group label: visible only when plotter_layout="Tabbed" plotting_group_label = Str("Plotting Tools") #: UI title for the data summary section summary_section_title = Str #: UI title for the categorical data summary section cat_summary_section_title = Str("Categorical data summary") #: UI title for the column list section column_list_section_title = Str("Column content") #: UI title for the summary content section summary_content_section_title = Str("Summary content") #: UI summary group (tab) name for numerical columns num_summary_group_name = Str("Numerical data") #: UI summary group (tab) name for categorical columns cat_summary_group_name = Str("Categorical data") #: Text to display in title bar of the containing window (if applicable) app_title = Str("Tabular Data Analyzer") #: How to place the plotter tool with respect to the exploration tool? 
plotter_layout = Enum("Tabbed", "HSplit", "VSplit", "popup") #: DFPlotManager traits to customize it plotter_kw = Dict #: Message displayed below the table if truncated truncation_msg = Property(Str, depends_on="model.num_displayed_rows") # Functionality controls -------------------------------------------------- #: Button to shuffle the order of the filtered data shuffle_button = Button("Shuffle") show_shuffle_button = Bool(True) #: Button to display more rows in the data table show_more_button = Button #: Button to display all rows in the data table show_all_button = Button("Show All") #: Apply button for the filter if model not in auto-apply mode apply_filter_button = ToolbarButton(image=apply_img) #: Edit the filter in a pop-out dialog pop_out_filter_button = ToolbarButton(image=pop_out_img) #: Whether to support saving, and loading filters filter_manager = Bool #: Button to launch filter expression manager to load an existing filter load_filter_button = ToolbarButton(image=load_img) #: Button to save current filter expression save_filter_button = ToolbarButton(image=save_img) #: Button to launch filter expression manager to modify saved filters manage_filter_button = ToolbarButton(image=manage_img) #: List of saved filtered expressions _known_expr = Property(Set, depends_on="model.known_filter_exps") #: Show the bottom panel with the summary of the data: _show_summary = Bool(True) allow_show_summary = Bool(True) #: Button to export the analyzed data to a CSV file data_exporter = Button("Export Data to CSV") #: Button to export the summary data to a CSV file summary_exporter = Button("Export Summary to CSV") # Detailed configuration traits ------------------------------------------- #: View class to use. Modify to customize. 
view_klass = Any(View) #: Width of the view view_width = Int(1100) #: Height of the view view_height = Int(700) #: Width of the filter box filter_item_width = Int(400) max_names_per_column = Int(12) truncation_msg_template = Str("Table truncated at {} rows") warn_if_sel_hidden = Bool(True) hidden_selection_msg = Str #: Column names (as a list) to include in filter editor assistant filter_editor_cols = List # Implementation details -------------------------------------------------- #: Evaluate number of columns to select panel or popup column control _many_columns = Property(Bool, depends_on="model.column_list") #: Popped-up UI to control the visible columns _control_popup = Any #: Collected traitsUI editors for both the data DF and the summary DF _df_editors = Dict # HasTraits interface ----------------------------------------------------- def __init__(self, **traits): if "model" in traits and isinstance(traits["model"], pd.DataFrame): traits["model"] = DataFrameAnalyzer(source_df=traits["model"]) super(DataFrameAnalyzerView, self).__init__(**traits) if self.include_plotter: # If a plotter view was specified, its model should be in the # model's list of plot managers: if self.plotter.model not in self.model.plot_manager_list: self.model.plot_manager_list.append(self.plotter.model) def traits_view(self): """ Putting the view components together. Each component of the view is built in a separate method so it can easily be subclassed and customized. 
""" # Construction of view groups ----------------------------------------- data_group = self.view_data_group_builder() column_controls_group = self.view_data_control_group_builder() summary_group = self.view_summary_group_builder() summary_controls_group = self.view_summary_control_group_builder() if self.show_categorical_summary: cat_summary_group = self.view_cat_summary_group_builder() else: cat_summary_group = None plotter_group = self.view_plotter_group_builder() button_content = [ Item("data_exporter", show_label=False), Spring(), Item("summary_exporter", show_label=False) ] if self.plotter_layout == "popup": button_content += [ Spring(), Item("plotter_launcher", show_label=False) ] button_group = HGroup(*button_content) # Organization of item groups ----------------------------------------- # If both types of summary are available, display as Tabbed view: if summary_group is not None and cat_summary_group is not None: summary_container = Tabbed( HSplit( summary_controls_group, summary_group, label=self.num_summary_group_name ), cat_summary_group, ) elif cat_summary_group is not None: summary_container = cat_summary_group else: summary_container = HSplit( summary_controls_group, summary_group ) # Allow to hide all summary information: summary_container.visible_when = "_show_summary" exploration_groups = VGroup( VSplit( HSplit( column_controls_group, data_group, ), summary_container ), button_group, label=self.exploration_group_label ) if self.include_plotter and self.plotter_layout != "popup": layout = getattr(traitsui.api, self.plotter_layout) groups = layout( exploration_groups, plotter_group ) else: groups = exploration_groups view = self.view_klass( groups, resizable=True, title=self.app_title, width=self.view_width, height=self.view_height ) return view # Traits view building methods -------------------------------------------- def view_data_group_builder(self): """ Build view element for the Data display """ editor_kw = dict(show_index=True, 
columns=self.visible_columns, fonts=self.fonts, formats=self.formats) data_editor = DataFrameEditor(selected_row="selected_idx", multi_select=True, **editor_kw) filter_group = HGroup( Item("model.filter_exp", label="Filter", width=self.filter_item_width), Item("pop_out_filter_button", show_label=False, style="custom", tooltip="Open filter editor..."), Item("apply_filter_button", show_label=False, visible_when="not model.filter_auto_apply", style="custom", tooltip="Apply current filter"), Item("save_filter_button", show_label=False, enabled_when="model.filter_exp not in _known_expr", visible_when="filter_manager", style="custom", tooltip="Save current filter"), Item("load_filter_button", show_label=False, visible_when="filter_manager", style="custom", tooltip="Load a filter..."), Item("manage_filter_button", show_label=False, visible_when="filter_manager", style="custom", tooltip="Manage filters..."), ) truncated = ("len(model.displayed_df) < len(model.filtered_df) and " "not model.show_selected_only") more_label = "Show {} More".format(self.model.num_display_increment) display_control_group = HGroup( Item("model.show_selected_only", label="Selected rows only"), Item("truncation_msg", style="readonly", show_label=False, visible_when=truncated), Item("show_more_button", editor=ButtonEditor(label=more_label), show_label=False, visible_when=truncated), Item("show_all_button", show_label=False, visible_when=truncated), ) data_group = VGroup( make_window_title_group(self.data_section_title, title_size=3, include_blank_spaces=False), HGroup( Item("model.sort_by_col", label="Sort by"), Item("shuffle_button", show_label=False, visible_when="show_shuffle_button"), Spring(), filter_group ), HGroup( Item("model.displayed_df", editor=data_editor, show_label=False), ), HGroup( Item("show_column_controls", label="\u2190 Show column control", visible_when="not _many_columns"), Item("open_column_controls", show_label=False, visible_when="_many_columns"), Spring(), 
Item("_show_summary", label=u'\u2193 Show summary', visible_when="allow_show_summary"), Spring(), display_control_group ), show_border=True ) return data_group def view_data_control_group_builder(self, force_visible=False): """ Build view element for the Data column control. Parameters ---------- force_visible : bool Controls visibility of the created group. Don't force for the group embedded in the global view, but force it when opened as a popup. """ num_cols = 1 + len(self.model.column_list) // self.max_names_per_column column_controls_group = VGroup( make_window_title_group(self.column_list_section_title, title_size=3, include_blank_spaces=False), Item("visible_columns", show_label=False, editor=CheckListEditor(values=self.model.column_list, cols=num_cols), # The custom style allows to control a list of options rather # than having a checklist editor for a single value: style='custom'), show_border=True ) if force_visible: column_controls_group.visible_when = "" else: column_controls_group.visible_when = "show_column_controls" return column_controls_group def view_summary_group_builder(self): """ Build view element for the numerical data summary display """ editor_kw = dict(show_index=True, columns=self.visible_columns, fonts=self.fonts, formats=self.formats) summary_editor = DataFrameEditor(**editor_kw) summary_group = VGroup( make_window_title_group(self.summary_section_title, title_size=3, include_blank_spaces=False), Item("model.summary_df", editor=summary_editor, show_label=False, visible_when="len(model.summary_df) != 0"), # Workaround the fact that the Label's visible_when is buggy: # encapsulate it into a group and add the visible_when to the group HGroup( Label("No data columns with numbers were found."), visible_when="len(model.summary_df) == 0" ), HGroup( Item("show_summary_controls"), Spring(), visible_when="len(model.summary_df) != 0" ), show_border=True, ) return summary_group def view_summary_control_group_builder(self): """ Build view element 
for the column controls for data summary. """ summary_controls_group = VGroup( make_window_title_group(self.summary_content_section_title, title_size=3, include_blank_spaces=False), Item("model.summary_index", show_label=False), visible_when="show_summary_controls", show_border=True ) return summary_controls_group def view_cat_summary_group_builder(self): """ Build view element for the categorical data summary display. """ editor_kw = dict(show_index=True, fonts=self.fonts, formats=self.formats) summary_editor = DataFrameEditor(**editor_kw) cat_summary_group = VGroup( make_window_title_group(self.cat_summary_section_title, title_size=3, include_blank_spaces=False), Item("model.summary_categorical_df", editor=summary_editor, show_label=False, visible_when="len(model.summary_categorical_df)!=0"), # Workaround the fact that the Label's visible_when is buggy: # encapsulate it into a group and add the visible_when to the group HGroup( Label("No data columns with numbers were found."), visible_when="len(model.summary_categorical_df)==0" ), show_border=True, label=self.cat_summary_group_name ) return cat_summary_group def view_plotter_group_builder(self): """ Build view element for the plotter tool. """ plotter_group = VGroup( Item("plotter", editor=InstanceEditor(), show_label=False, style="custom"), label=self.plotting_group_label ) return plotter_group # Public interface -------------------------------------------------------- def destroy(self): """ Clean up resources. """ if self._control_popup: self._control_popup.dispose() # Traits listeners -------------------------------------------------------- def _open_column_controls_fired(self): """ Pop-up a new view on the column list control. 
""" if self._control_popup and self._control_popup.control: # If there is an existing window, bring it in focus: # Discussion: https://stackoverflow.com/questions/2240717/in-qt-how-do-i-make-a-window-be-the-current-window # noqa self._control_popup.control._mw.activateWindow() return # Before viewing self with a simplified view, make sure the original # view editors are collected so they can be modified when the controls # are used: if not self._df_editors: self._collect_df_editors() view = self.view_klass( self.view_data_control_group_builder(force_visible=True), buttons=[OKButton], width=600, resizable=True, title="Control visible columns" ) # WARNING: this will modify the info object the view points to! self._control_popup = self.edit_traits(view=view, kind="live") def _shuffle_button_fired(self): self.model.shuffle_filtered_df() def _apply_filter_button_fired(self): flt = self.model.filter_exp msg = f"Applying filter {flt}." logger.log(ACTION_LEVEL, msg) self.model.recompute_filtered_df() def _pop_out_filter_button_fired(self): if not self.filter_editor_cols: # if there are no included columns, then use all categorical cols df = self.model.source_df cat_df = df.select_dtypes(include=CATEGORICAL_COL_TYPES) self.filter_editor_cols = list(cat_df.columns) filter_editor = FilterExpressionEditorView( expr=self.model.filter_exp, view_klass=self.view_klass, source_df=self.model.source_df, included_cols=self.filter_editor_cols) ui = filter_editor.edit_traits(kind="livemodal") if ui.result: self.model.filter_exp = filter_editor.expr self.apply_filter_button = True def _manage_filter_button_fired(self): """ TODO: review if replacing the copy by a deepcopy or removing the copy altogether would help traits trigger listeners correctly """ msg = "Opening filter manager." 
logger.log(ACTION_LEVEL, msg) # Make a copy of the list of filters so the model can listen to changes # even if only a field of an existing filter is modified: filter_manager = FilterExpressionManager( known_filter_exps=copy(self.model.known_filter_exps), mode="manage", view_klass=self.view_klass ) ui = filter_manager.edit_traits(kind="livemodal") if ui.result: # FIXME: figure out why this simpler assignment doesn't trigger the # traits listener on the model when changing a FilterExpression # attribute: # self.model.known_filter_exps = filter_manager.known_filter_exps self.model.known_filter_exps = [ FilterExpression(name=e.name, expression=e.expression) for e in filter_manager.known_filter_exps ] def _load_filter_button_fired(self): filter_manager = FilterExpressionManager( known_filter_exps=self.model.known_filter_exps, mode="load", view_klass=self.view_klass ) ui = filter_manager.edit_traits(kind="livemodal") if ui.result: selection = filter_manager.selected_expression self.model.filter_exp = selection.expression def _save_filter_button_fired(self): exp = self.model.filter_exp if exp in [e.expression for e in self.model.known_filter_exps]: return expr = FilterExpression(name=exp, expression=exp) self.model.known_filter_exps.append(expr) def _show_more_button_fired(self): self.model.num_displayed_rows += self.model.num_display_increment def _show_all_button_fired(self): self.model.num_displayed_rows = -1 @on_trait_change("model:selected_data_in_plotter_updated", post_init=True) def warn_if_selection_hidden(self): """ Pop up warning msg if some of the selected rows aren't displayed. 
""" if not self.warn_if_sel_hidden: return if not self.model.selected_idx: return truncated = len(self.model.displayed_df) < len(self.model.filtered_df) max_displayed = self.model.displayed_df.index.max() some_selection_hidden = max(self.model.selected_idx) > max_displayed if truncated and some_selection_hidden: warning(None, self.hidden_selection_msg, "Hidden selection") @on_trait_change("visible_columns[]", post_init=True) def update_filtered_df_on_columns(self): """ Just show the columns that are set to visible. Notes ----- We are not modifying the filtered data because if we remove a column and then bring it back, the adapter breaks because it is missing data. Breakage happen when removing a column if the model is changed first, or when bring a column back if the adapter column list is changed first. """ if not self.info.initialized: return if not self._df_editors: self._collect_df_editors() # Rebuild the column list (col name, column id) for the tabular # adapter: all_visible_cols = [(col, col) for col in self.visible_columns] df = self.model.source_df cat_dtypes = self.model.categorical_dtypes summarizable_df = df.select_dtypes(exclude=cat_dtypes) summary_visible_cols = [(col, col) for col in self.visible_columns if col in summarizable_df.columns] for df_name, cols in zip(["displayed_df", "summary_df"], [all_visible_cols, summary_visible_cols]): df = getattr(self.model, df_name) index_name = df.index.name if index_name is None: index_name = '' # This grabs the corresponding _DataFrameEditor (not the editor # factory) which has access to the adapter object: editor = self._df_editors[df_name] editor.adapter.columns = [(index_name, 'index')] + cols def _collect_df_editors(self): for df_name in ["displayed_df", "summary_df"]: try: # This grabs the corresponding _DataFrameEditor (not the editor # factory) which has access to the adapter object: self._df_editors[df_name] = getattr(self.info, df_name) except Exception as e: msg = "Error trying to collect the tabular 
adapter: {}" logger.error(msg.format(e)) def _plotter_launcher_fired(self): """ Pop up plot manager view. Only when self.plotter_layout="popup". """ self.plotter.edit_traits(kind="livemodal") def _data_exporter_fired(self): filepath = request_csv_file(action="save as") if filepath: self.model.filtered_df.to_csv(filepath) open_file(filepath) def _summary_exporter_fired(self): filepath = request_csv_file(action="save as") if filepath: self.model.summary_df.to_csv(filepath) open_file(filepath) # Traits property getters/setters ----------------------------------------- def _get__known_expr(self): return {e.expression for e in self.model.known_filter_exps} @cached_property def _get_truncation_msg(self): num_displayed_rows = self.model.num_displayed_rows return self.truncation_msg_template.format(num_displayed_rows) @cached_property def _get__many_columns(self): # Many columns means more than 2 columns: return len(self.model.column_list) > 2 * self.max_names_per_column # Traits initialization methods ------------------------------------------- def _plotter_default(self): if self.include_plotter: if self.model.plot_manager_list: if len(self.model.plot_manager_list) > 1: num_plotters = len(self.model.plot_manager_list) msg = "Model contains {} plot manager, but only " \ "initializing the Analyzer view with the first " \ "plot manager available.".format(num_plotters) logger.warning(msg) plot_manager = self.model.plot_manager_list[0] else: plot_manager = DataFramePlotManager( data_source=self.model.filtered_df, source_analyzer=self.model, **self.plotter_kw ) view = DataFramePlotManagerView(model=plot_manager, view_klass=self.view_klass) return view def _formats_default(self): if self.display_precision < 0: return '%s' else: formats = {} float_format = '%.{}g'.format(self.display_precision) for col in self.model.source_df.columns: col_dtype = self.model.source_df.dtypes[col] if np.issubdtype(col_dtype, np.number): formats[col] = float_format else: formats[col] = '%s' return 
formats def _visible_columns_default(self): return self.model.column_list def _hidden_selection_msg_default(self): msg = "The displayed data is truncated and some of the selected " \ "rows isn't displayed in the data table." return msg def _summary_section_title_default(self): if len(self.model.summary_categorical_df) == 0: return "Data summary" else: return "Numerical data summary" def _fonts_default(self): return "{} {}".format(self.font_name, self.font_size)
class GeoNDGrid(Source):
    '''
    Specification and representation of an nD-grid. GridND

    The grid spans the box between the corner points ``x_mins`` and
    ``x_maxs``.  ``shape`` holds the number of elements along x, y and z and
    ``active_dims`` selects which of those directions are in use; inactive
    directions are collapsed to a single node (see ``_get_n_act_nodes``).
    '''

    # The name of our scalar array.
    scalar_name = Str('scalar')

    # map of coordinate labels to the indices
    _dim_map = {'x': 0, 'y': 1, 'z': 2}

    # currently active dimensions
    active_dims = List(Str, ['x', 'y'])

    # Bottom left corner
    x_mins = Instance(GridPoint, label='Corner 1')

    def _x_mins_default(self):
        '''Bottom left corner'''
        return GridPoint()

    # Upper right corner
    x_maxs = Instance(GridPoint, label='Corner 2')

    def _x_maxs_default(self):
        '''Upper right corner'''
        return GridPoint(x=1, y=1, z=1)

    # indices of the currently active dimensions
    dim_indices = Property(Array(int), depends_on='active_dims')

    @cached_property
    def _get_dim_indices(self):
        ''' Get active indices (0 for x, 1 for y, 2 for z) '''
        return array([self._dim_map[dim_ix] for dim_ix in self.active_dims],
                     dtype='int_')

    # number of currently active dimensions
    n_dims = Property(Int, depends_on='active_dims')

    @cached_property
    def _get_n_dims(self):
        '''Number of currently active dimensions'''
        return len(self.active_dims)

    # number of elements in each direction
    # @todo: rename to n_faces
    shape = Tuple(int, int, int, label='Elements')

    def _shape_default(self):
        '''Number of elements in each direction'''
        return (1, 0, 0)

    n_act_nodes = Property(Array, depends_on='shape, active_dims')

    @cached_property
    def _get_n_act_nodes(self):
        '''Number of nodes per direction respecting the active_dims.

        Active directions get ``shape + 1`` nodes, inactive directions a
        single node.
        '''
        act_idx = ones((3, ), int)
        shape = array(list(self.shape), dtype=int)
        act_idx[self.dim_indices] += shape[self.dim_indices]
        return act_idx

    # total number of nodes of the system grid
    n_nodes = Property(Int, depends_on='shape, active_dims')

    @cached_property
    def _get_n_nodes(self):
        '''Number of nodes used for the geometry approximation'''
        return reduce(lambda x, y: x * y, self.n_act_nodes)

    enum_nodes = Property(Array, depends_on='shape,active_dims')

    @cached_property
    def _get_enum_nodes(self):
        '''
        Return an array of node numbers laid out like the grid, i.e. with
        shape ``n_act_nodes``.

        ``arange(...).reshape`` uses NumPy's default C order, so the last
        axis varies fastest in the numbering.
        '''
        return arange(self.n_nodes).reshape(tuple(self.n_act_nodes))

    grid = Property(Array, depends_on='shape,active_dims,x_mins.+,x_maxs.+')

    @cached_property
    def _get_grid(self):
        '''
        Coordinates of all nodes as produced by ``numpy.mgrid``.

        slice(start, stop, step) with an imaginary step ``n*1j`` makes mgrid
        return ``n`` points in that direction, including 'stop'.
        '''
        slices = [
            slice(x_min, x_max, complex(0, n_n))
            for x_min, x_max, n_n in zip(self.x_mins, self.x_maxs,
                                         self.n_act_nodes)
        ]
        return mgrid[tuple(slices)]

    #-------------------------------------------------------------------------
    # Visualization pipelines
    #-------------------------------------------------------------------------

    # mvp_mgrid_geo = Trait(MVPolyData)
    #
    # def _mvp_mgrid_geo_default(self):
    #    return MVPolyData(name='Mesh geomeetry',
    #                      points=self._get_points,
    #                      lines=self._get_lines,
    #                      polys=self._get_faces,
    #                      scalars=self._get_random_scalars
    #                      )
    #
    # mvp_mgrid_labels = Trait(MVPointLabels)
    #
    # def _mvp_mgrid_labels_default(self):
    #    return MVPointLabels(name='Mesh numbers',
    #                         points=self._get_points,
    #                         scalars=self._get_random_scalars,
    #                         vectors=self._get_points)

    changed = Button('Draw')

    @on_trait_change('changed')
    def redraw(self):
        '''Redraw the mesh geometry and the node labels.'''
        # NOTE(review): mvp_mgrid_geo / mvp_mgrid_labels are only defined in
        # the commented-out block above, so this handler raises
        # AttributeError unless a subclass provides them.
        self.mvp_mgrid_geo.redraw()
        self.mvp_mgrid_labels.redraw('label_scalars')

    def _get_points(self):
        '''
        Reshape the grid into a column: one (x, y, z) row per node.
        '''
        return c_[tuple([self.grid[i].flatten() for i in range(3)])]

    def _get_n_lines(self):
        '''
        Get the number of line elements.

        Inactive dimensions contribute a factor of 1 and active dimensions
        their number of elements, exactly as in ``_get_n_faces``.  This
        equals the number of base nodes enumerated by ``_get_lines``.
        '''
        act_idx = ones((3, ), int)
        # ``self.shape`` is a plain tuple; convert it so that it can be
        # indexed with the integer array ``dim_indices``.
        shape = array(self.shape, dtype=int)
        act_idx[self.dim_indices] = shape[self.dim_indices]
        return reduce(lambda x, y: x * y, act_idx)

    def _get_lines(self):
        '''
        Only return data if n_dims = 1; otherwise an empty int array.

        Each row holds the two node numbers of one line element.
        '''
        if self.n_dims != 1:
            return array([], int)
        #
        # Get the list of all base nodes (every node except the last one
        # along the active direction).
        #
        tidx = ones((3, ), dtype='int_')
        tidx[self.dim_indices] = -1
        slices = tuple([slice(0, idx) for idx in tidx])
        base_node_list = self.enum_nodes[slices].flatten()
        #
        # Get the node map within the line
        #
        ijk_arr = zeros((3, 2), dtype=int)
        ijk_arr[self.dim_indices[0]] = [0, 1]
        offsets = self.enum_nodes[ijk_arr[0], ijk_arr[1], ijk_arr[2]]
        #
        # Setup and fill the array with line connectivities
        #
        n_lines = self._get_n_lines()
        lines = zeros((n_lines, 2), dtype='int_')
        for n_idx, base_node in enumerate(base_node_list):
            lines[n_idx, :] = offsets + base_node
        return lines

    def _get_n_faces(self):
        '''Return the number of faces.

        The number is determined by putting 1 into inactive dimensions and
        shape into the active dimensions.
        '''
        act_idx = ones((3, ), int)
        shape = array(self.shape, dtype=int)
        act_idx[self.dim_indices] = shape[self.dim_indices]
        return reduce(lambda x, y: x * y, act_idx)

    def _get_faces(self):
        '''
        Only return data if n_dims = 2; otherwise an empty int array.

        Each row holds the four corner node numbers of one face.
        '''
        if self.n_dims != 2:
            return array([], int)
        #
        # get the slices extracting all corner nodes with
        # the smallest node number within the element
        #
        tidx = ones((3, ), dtype='int_')
        tidx[self.dim_indices] = -1
        slices = tuple([slice(0, idx) for idx in tidx])
        base_node_list = self.enum_nodes[slices].flatten()
        #
        # get the node map within the face
        #
        ijk_arr = zeros((3, 4), dtype=int)
        ijk_arr[self.dim_indices[0]] = [0, 0, 1, 1]
        ijk_arr[self.dim_indices[1]] = [0, 1, 1, 0]
        offsets = self.enum_nodes[ijk_arr[0], ijk_arr[1], ijk_arr[2]]
        #
        # setup and fill the array with face connectivities
        #
        n_faces = self._get_n_faces()
        faces = zeros((n_faces, 4), dtype='int_')
        for n_idx, base_node in enumerate(base_node_list):
            faces[n_idx, :] = offsets + base_node
        return faces

    def _get_volumes(self):
        '''
        Only return data if n_dims = 3; otherwise an empty int array.

        Each row holds the eight corner node numbers of one volume cell.
        '''
        if self.n_dims != 3:
            return array([], int)

        tidx = ones((3, ), dtype='int_')
        tidx[self.dim_indices] = -1
        slices = tuple([slice(0, idx) for idx in tidx])

        en = self.enum_nodes
        offsets = array([
            en[0, 0, 0], en[0, 1, 0], en[1, 1, 0], en[1, 0, 0], en[0, 0, 1],
            en[0, 1, 1], en[1, 1, 1], en[1, 0, 1]
        ], dtype='int_')
        base_node_list = self.enum_nodes[slices].flatten()

        # With all three directions active, _get_n_faces() is the product of
        # ``shape``, i.e. the number of volume cells.
        n_cells = self._get_n_faces()
        volumes = zeros((n_cells, 8), dtype='int_')
        # Rows are addressed by the cell's position in base_node_list, not by
        # the base node number (which can exceed the number of cells).
        for n_idx, base_node in enumerate(base_node_list):
            volumes[n_idx, :] = offsets + base_node
        return volumes

    # Identifiers
    var = Str('dummy')
    idx = Int(0)

    def _get_random_scalars(self):
        '''One Weibull-distributed random value per node.'''
        return random.weibull(1, size=self.n_nodes)

    traits_view = View(HSplit(
        Group(
            Item('changed', show_label=False),
            Item('active_dims@',
                 editor=CheckListEditor(values=['x', 'y', 'z'], cols=3)),
            Item('x_mins@', resizable=False),
            Item('x_maxs@'),
            Item('shape@'),
        ),
    ), resizable=True)
class HeadViewController(HasTraits):
    """
    Camera presets (front / left / right / top) for a head shown in a scene.

    Parameters
    ----------
    system : 'RAS' | 'ALS' | 'ARI'
        Coordinate system described as initials for directions associated
        with the x, y, and z axes. Relevant terms are: Anterior, Right, Left,
        Superior, Inferior.
    """
    system = Enum("RAS", "ALS", "ARI", desc="Coordinate system: directions of "
                  "the x, y, and z axis.")

    right = Button()
    front = Button()
    left = Button()
    top = Button()

    scale = Float(0.16)

    scene = Instance(MlabSceneModel)

    view = View(
        VGrid('0', 'top', '0',
              Item('scale', label='Scale', show_label=True),
              'right', 'front', 'left',
              show_labels=False, columns=4))

    @on_trait_change('scene.activated')
    def _init_view(self):
        """Switch to parallel projection and tie ``scale`` to the camera."""
        self.scene.parallel_projection = True

        # 'scene.activated' can fire several times; only wire up the camera
        # once a renderer exists.
        if not self.scene.renderer:
            return
        self.sync_trait('scale', self.scene.camera, 'parallel_scale')
        # A scale change does not trigger a render by itself, so request one.
        self.on_trait_change(self.scene.render, 'scale')

    @on_trait_change('top,left,right,front')
    def on_set_view(self, view, _):
        """Point the camera at the preset named by the clicked button.

        Raises ValueError for an unknown coordinate system or view name.
        """
        if self.scene is None:
            return

        # (azimuth, elevation, roll) per coordinate system and view name.
        presets = {
            'ALS': {
                'front': (0, 90, -90),
                'left': (90, 90, 180),
                'right': (-90, 90, 0),
                'top': (0, 0, -90),
            },
            'RAS': {
                'front': (90, 90, 180),
                'left': (180, 90, 90),
                'right': (0, 90, 270),
                'top': (90, 0, 180),
            },
            'ARI': {
                'front': (0, 90, 90),
                'left': (-90, 90, 180),
                'right': (90, 90, 0),
                'top': (0, 180, 90),
            },
        }

        system = self.system
        per_view = presets.get(system)
        if per_view is None:
            raise ValueError("Invalid system: %r" % system)

        angles = per_view.get(view)
        if angles is None:
            raise ValueError("Invalid view: %r" % view)

        azimuth, elevation, roll = angles
        self.scene.mlab.view(distance=None, reset_roll=True,
                             figure=self.scene.mayavi_scene,
                             azimuth=azimuth, elevation=elevation, roll=roll)
class PossionDemo(HasTraits):
    """Side-by-side image blending demo.

    "Load" reads both image files and shows them next to each other.  A mask
    drawn on either image is mirrored onto the other one, and "Mix" blends
    the masked part of the left image into the right image with
    ``possion_mix``.  "Clear" erases both masks.
    """
    left_mask = Instance(ImageMaskDrawer)
    right_mask = Instance(ImageMaskDrawer)
    left_file = File(path.join(FOLDER, "vinci_src.png"))
    right_file = File(path.join(FOLDER, "vinci_target.png"))
    figure = Instance(Figure, ())
    load_button = Button("Load")
    clear_button = Button("Clear")
    mix_button = Button("Mix")

    view = View(
        Item("left_file"),
        Item("right_file"),
        VGroup(
            HGroup(
                "load_button", "clear_button", "mix_button",
                show_labels=False
            ),
            Group(
                Item("figure", editor=MPLFigureEditor(toolbar=False)),
                show_labels=False,
                orientation='horizontal'
            )
        ),
        width=800,
        height=600,
        resizable=True,
        title="Possion Demo")

    def __init__(self, **kw):
        super(PossionDemo, self).__init__(**kw)
        # One subplot per image: left = source, right = target.
        self.left_axe = self.figure.add_subplot(121)
        self.right_axe = self.figure.add_subplot(122)

    def _images_loaded(self):
        """Return True once load_images() has created both mask drawers."""
        # The masks are only assigned after both images were read and shown,
        # so this also implies left_pic / right_pic / right_img exist.
        return self.left_mask is not None and self.right_mask is not None

    def load_images(self):
        """Read both image files and (re)create the displayed images/masks."""
        self.left_pic = cv2.imread(self.left_file, 1)
        self.right_pic = cv2.imread(self.right_file, 1)
        # Both masks get the same shape, large enough to cover either image.
        shape = [max(v1, v2) for v1, v2 in
                 zip(self.left_pic.shape[:2], self.right_pic.shape[:2])]

        # [::-1, :, ::-1] flips the rows (the axes use origin="lower") and
        # reverses the channel order (presumably OpenCV BGR -> RGB).
        self.left_img = self.left_axe.imshow(self.left_pic[::-1, :, ::-1],
                                             origin="lower")
        self.left_axe.axis("off")
        self.left_mask = ImageMaskDrawer(self.left_axe, self.left_img,
                                         mask_shape=shape)

        self.right_img = self.right_axe.imshow(self.right_pic[::-1, :, ::-1],
                                               origin="lower")
        self.right_axe.axis("off")
        self.right_mask = ImageMaskDrawer(self.right_axe, self.right_img,
                                          mask_shape=shape)

        # Keep both masks in sync whenever either one is drawn on.
        self.left_mask.on_trait_change(self.mask_changed, "drawed")
        self.right_mask.on_trait_change(self.mask_changed, "drawed")
        self.figure.canvas.draw_idle()

    def mask_changed(self, obj, name, new):
        """Copy the mask just drawn on one image onto the other image."""
        if obj is self.left_mask:
            src, target = self.left_mask, self.right_mask
        else:
            src, target = self.right_mask, self.left_mask
        target.array.fill(0)
        target.array[:, :] = src.array[:, :]
        target.update()
        self.figure.canvas.draw()

    def _load_button_fired(self):
        self.load_images()

    def _mix_button_fired(self):
        """Blend the masked region of the left image into the right image."""
        # Before "Load" there are no images or masks to mix.
        if not self._images_loaded():
            return
        lh, lw, _ = self.left_pic.shape
        rh = self.right_pic.shape[0]
        # Take the part of the mask covering the left image (bottom lh rows,
        # since the display is vertically flipped) and its last channel.
        left_mask = self.left_mask.array[-lh:, :lw, -1]
        if np.all(left_mask == 0):
            # Nothing has been masked yet.
            return
        dx, dy = self.right_mask.get_mask_offset()
        # NOTE(review): rh - lh - dy presumably converts the vertical offset
        # from the flipped display coordinates to image rows — confirm.
        result = possion_mix(self.left_pic, self.right_pic, left_mask,
                             (dx, rh - lh - dy))

        self.right_img.set_data(result[::-1, :, ::-1])
        self.right_mask.mask_img.set_visible(False)
        self.figure.canvas.draw()

    def _clear_button_fired(self):
        """Erase both masks and redraw."""
        # Before "Load" the masks are still None.
        if not self._images_loaded():
            return
        self.left_mask.array.fill(0)
        self.right_mask.array.fill(0)
        self.left_mask.update()
        self.right_mask.update()
        self.figure.canvas.draw()
class Spylot(HasTraits):
    """
    This class represents the spylot application state.
    """

    # Class attribute selecting the default line lists to use:
    # 'galaxy', 'stellar', or None for no default lines.
    defaultlines = 'galaxy'

    plot = Instance(Plot)
    histplot = Bool(True)

    majorlineeditor = LineListEditor
    minorlineeditor = LineListEditor
    linenamelist = Property(Tuple)
    showmajorlines = Bool(True)
    showminorlines = Bool(False)

    labels = List(DataLabel)
    showlabels = Bool(False)

    editmajor = Button('Edit Major')
    editminor = Button('Edit Minor')

    specs = List(Instance(spec.Spectrum))
    currspeci = Int
    currspecip1 = Property(depends_on='currspeci')
    lowerspecip1 = Property(depends_on='currspeci')
    upperspecip1 = Property(depends_on='currspeci')
    currspec = Property
    lastspec = Instance(spec.Spectrum)
    z = Float
    lowerz = Float(0.0)
    upperz = Float(1.0)
    zviewtype = Enum('redshift', 'velocity')
    zview = Property(Float, depends_on='z,zviewtype')
    uzview = Property(Float, depends_on='upperz,zviewtype')
    lzview = Property(Float, depends_on='lowerz,zviewtype')
    coarserz = Button('Coarser')
    finerz = Button('Finer')
    # Lowest and highest redshift-quality codes known to spec.Spectrum.
    _zql, _zqh = min(spec.Spectrum._zqmap.keys()), max(
        spec.Spectrum._zqmap.keys())
    zqual = Range(_zql, _zqh, -1)

    spechanged = Event
    specleft = Button('<')
    specright = Button('>')

    scaleerr = Bool(False)
    scaleerrfraclow = Range(0.0, 1.0, 1.0)
    scaleerrfrachigh = Float(1.0)
    fluxformat = Button('Flux Line Format')
    errformat = Button('Error Line Format')
    showcoords = Bool(False)
    showgrid = Bool(True)

    dosmoothing = Bool(False)
    smoothing = Float(3)

    contsub = Button('Fit Continuum...')
    contclear = Button('Clear Continuum')
    showcont = Bool(False)
    contformat = Button('Continuum Line Format')

    _titlestr = Str('Spectrum 0/0')
    _oldunit = Str('')
    maintool = Tuple(Instance(Interactor), Instance(AbstractOverlay))

    featureselmode = Enum([
        'No Selection', 'Click Select', 'Range Select', 'Base Select',
        'Click Delete'
    ])
    editfeatures = Button('Features...')
    showfeatures = Bool(True)
    featurelocsmooth = Float(None)
    featurelocsize = Int(200)
    featurelist = List(Instance(spec.SpectralFeature))

    selectedfeatureindex = Int
    deletefeature = Button('Delete')
    idfeature = Button('Identify')
    recalcfeature = Button('Recalculate')
    clearfeatures = Button('Clear')

    delcurrspec = Button('Delete Current')
    saveloadfile = File(filter=['*.specs'])
    savespeclist = Button('Save Spectra')
    loadspeclist = Button('Load Spectra')
    loadaddfile = File(filter=['*.fits'])
    loadaddspec = Button('Add Spectrum')
    loadaddspectype = Enum('wcs', 'deimos', 'astropysics')

    titlegroup = HGroup(
        Item('specleft', show_label=False, enabled_when='currspeci>0'),
        spring,
        Label('Spectrum #', height=0.5),
        Item('currspecip1', show_label=False,
             editor=RangeEditor(low_name='lowerspecip1',
                                high_name='upperspecip1',
                                mode='spinner')),
        Item('_titlestr', style='readonly', show_label=False),
        spring,
        Item('specright', show_label=False,
             enabled_when='currspeci<(len(specs)-1)'))

    speclistgroup = HGroup(
        Label('Spectrum List:'),
        spring,
        Item('delcurrspec', show_label=False, enabled_when='len(specs)>1'),
        Item('saveloadfile', show_label=False),
        Item('savespeclist', show_label=False),
        Item('loadspeclist', show_label=False,
             enabled_when='os.path.exists(saveloadfile)'),
        Item('loadaddfile', show_label=False),
        # "Add Spectrum" sits next to loadaddfile (*.fits) and
        # loadaddspectype, so gate it on that file rather than on the
        # *.specs list file (presumably a copy-paste slip — confirm against
        # the _loadaddspec_fired handler).
        Item('loadaddspec', show_label=False,
             enabled_when='os.path.exists(loadaddfile)'),
        Item('loadaddspectype', show_label=False),
        spring)

    plotformatgroup = HGroup(
        spring,
        Item('fluxformat', show_label=False),
        Item('errformat', show_label=False),
        Item('scaleerr', label='Scale Error?'),
        Item('scaleerrfraclow', label='Lower', enabled_when='scaleerr',
             editor=TextEditor(evaluate=float)),
        Item('scaleerrfrachigh', label='Upper', enabled_when='scaleerr'),
        Item('showgrid', label='Grid?'),
        Item('showcoords', label='Coords?'),
        spring)

    featuregroup = HGroup(
        spring,
        Item('showmajorlines', label='Show major?'),
        Item('editmajor', show_label=False),
        Item('showlabels', label='Labels?'),
        Item('showminorlines', label='Show minor?'),
        Item('editminor', show_label=False),
        spring,
        Item('editfeatures', show_label=False),
        Item('featureselmode', show_label=False),
        spring)

    continuumgroup = HGroup(
        spring,
        Item('contsub', show_label=False),
        Item('contclear', show_label=False),
        Item('showcont', label='Continuum line?'),
        Item('contformat', show_label=False),
        Item('dosmoothing', label='Smooth?'),
        Item('smoothing', show_label=False, enabled_when='dosmoothing'),
        spring)

    zgroup = VGroup(
        HGroup(
            Item('zviewtype', show_label=False),
            Item('zview',
                 editor=RangeEditor(low_name='lzview', high_name='uzview',
                                    format='%5.4g ', mode='slider'),
                 show_label=False, springy=True)),
        HGroup(
            Item('lzview', show_label=False),
            Item('coarserz', show_label=False),
            spring,
            Item('zqual', style='custom', label='Z quality',
                 editor=RangeEditor(cols=_zqh - _zql + 1, low=_zql,
                                    high=_zqh)),
            spring,
            Item('finerz', show_label=False),
            Item('uzview', show_label=False)))

    features_view = View(
        VGroup(
            HGroup(Item('showfeatures', label='Show'),
                   Item('featurelocsmooth', label='Locator Smoothing'),
                   Item('featurelocsize', label='Locator Window size')),
            Item('featurelist',
                 editor=TabularEditor(adapter=FeaturesAdapter(),
                                      selected_row='selectedfeatureindex'),
                 show_label=False),
            HGroup(
                Item('deletefeature', show_label=False,
                     enabled_when=
                     'len(featurelist)>0 and selectedfeatureindex>=0'),
                Item('idfeature', show_label=False,
                     enabled_when=
                     'len(featurelist)>0 and selectedfeatureindex>=0'),
                Item('recalcfeature', show_label=False,
                     enabled_when=
                     'len(featurelist)>0 and selectedfeatureindex>=0'),
                Item('clearfeatures', show_label=False,
                     enabled_when='len(featurelist)>0'))),
        resizable=True,
        title='Spylot Features')

    traits_view = View(
        VGroup(
            Include('titlegroup'),
            Include('speclistgroup'),
            Include('plotformatgroup'),
            Include('featuregroup'),
            Include('continuumgroup'),
            Item('plot', editor=ComponentEditor(), show_label=False,
                 width=768),
            Include('zgroup')),
        resizable=True,
        title='Spectrum Plotter',
        handler=SpylotHandler(),
        key_bindings=spylotkeybindings)

    def __init__(self, specs, **traits):
        """
        :param specs:
            The spectra/um to be analyzed as a sequence or a single object.
            ``None`` is treated as an empty sequence.
        :type specs: :class:`astropysics.spec.Spectrum`

        kwargs are passed in as additional traits.
        """
        # Placeholder data; each entry is replaced later (see the comments).
        #pd = ArrayPlotData(x=[1],x0=[1],flux=[1],err=[1]) #reset by spechanged event
        pd = ArrayPlotData(x=[1], flux=[1], err=[1])  #reset by spechanged event
        pd.set_data('majorx', [1, 1])  #reset by majorlines change
        pd.set_data('majory', [0, 1])  #reset by majorlines change
        pd.set_data('minorx', [1, 1])  #reset by minorlines change
        pd.set_data('minory', [0, 1])  #reset by minorlines change
        pd.set_data('continuum', [0, 0])  #reset

        self.plot = plot = Plot(pd, resizeable='hv')
        # Insert the custom layers just before the standard 'plot' layer.
        ploi = plot.draw_order.index('plot')
        plot.draw_order[ploi:ploi] = [
            'continuum', 'err', 'flux', 'annotations'
        ]
        plot.plot(('x', 'flux'), name='flux', type='line',
                  line_style='solid', color='blue', draw_layer='flux',
                  unified_draw=True)
        plot.plot(('x', 'err'), name='err', type='line', line_style='dash',
                  color='green', draw_layer='err', unified_draw=True)

        # Secondary axis along the top, kept in sync with the x mapper.
        topmapper = LinearMapper(range=DataRange1D())
        plot.x_mapper.range.on_trait_change(self._update_upperaxis_range,
                                            'updated')
        plot.x_mapper.on_trait_change(self._update_upperaxis_screen,
                                      'updated')
        self.upperaxis = PlotAxis(plot, orientation='top', mapper=topmapper)
        plot.overlays.append(self.upperaxis)

        # Two value mappers for the error curve: the plot's own one and a
        # fixed 0..1 range.
        self.errmapperfixed = plot.plots['err'][0].value_mapper
        self.errmapperscaled = LinearMapper(range=DataRange1D(high=1.0, low=0))
        plot.x_mapper.on_trait_change(self._update_errmapper_screen,
                                      'updated')

        plot.padding_top = 30  #default is a bit much
        plot.padding_left = 70  #default is a bit too little

        majorlineplot = plot.plot(('majorx', 'majory'), name='majorlineplot',
                                  type='line', line_style='dash',
                                  color='red')[0]
        majorlineplot.set(draw_layer='annotations', unified_draw=True)
        majorlineplot.value_mapper = LinearMapper(
            range=DataRange1D(high=0.9, low=0.1))
        majorlineplot.visible = self.showmajorlines
        # remove the line plot from the x_mapper sources so scaling is only
        # on the spectrum
        del plot.x_mapper.range.sources[-1]
        self.majorlineeditor = LineListEditor(lineplot=majorlineplot)

        minorlineplot = plot.plot(('minorx', 'minory'), name='minorlineplot',
                                  type='line', line_style='dot',
                                  color='red')[0]
        minorlineplot.set(draw_layer='annotations', unified_draw=True)
        minorlineplot.value_mapper = LinearMapper(
            range=DataRange1D(high=0.9, low=0.1))
        minorlineplot.visible = self.showminorlines
        # remove the line plot from the x_mapper sources so scaling is only
        # on the spectrum
        del plot.x_mapper.range.sources[-1]
        self.minorlineeditor = LineListEditor(lineplot=minorlineplot)

        self.contline = plot.plot(('x', 'continuum'), name='continuum',
                                  type='line', line_style='solid',
                                  color='black')[0]
        self.contline.set(draw_layer='continuum', unified_draw=True)
        self.contline.visible = self.showcont
        #        idat = ArrayDataSource((0.0,1.0))
        #        vdat = ArrayDataSource((0.0,0.0))
        #        self.zeroline = LinePlot(index=idat,value=vdat,line_style='solid',color='black')
        #        self.zeroline.index_mapper = LinearMapper(range=DataRange1D(high=0.9,low=0.1))
        #        self.zeroline.value_mapper = self.plot.y_mapper
        #        self.zeroline.visible = self.showcont
        #        self.plot.add(self.zeroline)

        if Spylot.defaultlines:
            defaultlines = _get_default_lines(Spylot.defaultlines)
            self.majorlineeditor.candidates = defaultlines[0]
            self.majorlineeditor.selectednames = defaultlines[1]
            self.minorlineeditor.candidates = defaultlines[0]
            self.minorlineeditor.selectednames = defaultlines[2]

        plot.tools.append(PanTool(plot))
        plot.tools.append(ZoomTool(plot))
        # The zoom tool is also registered as an overlay.
        plot.overlays.append(plot.tools[-1])

        self.coordtext = TextBoxOverlay(component=plot, align='ul')
        plot.overlays.append(self.coordtext)
        plot.tools.append(MouseMoveReporter(overlay=self.coordtext, plot=plot))
        self.coordtext.visible = self.showcoords

        self.linehighlighter = lho = LineHighlighterOverlay(component=plot)
        lho.visible = self.showfeatures
        plot.overlays.append(lho)

        # Normalise ``specs`` to a list.
        if specs is None:
            specs = []
        elif isinstance(specs, spec.Spectrum):
            specs = [specs]
        self.specs = specs

        self.spechanged = True

        self.on_trait_change(self._majorlines_changed,
                             'majorlineeditor.selectedobjs')
        self.on_trait_change(self._minorlines_changed,
                             'minorlineeditor.selectedobjs')

        super(Spylot, self).__init__(**traits)

    def __del__(self):
        # Replace the current spectrum's feature container with a plain list
        # copy on teardown.
        # NOTE(review): presumably this detaches the features from traits
        # machinery; AttributeError/IndexError are expected when there is no
        # current spectrum — confirm.
        try:
            self.currspec._features = list(self.currspec._features)
        except (AttributeError, IndexError):
            pass
class DofCellGrid(SDomain):

    '''
    Get an array with element Dof numbers.

    Degrees of freedom (DOFs) are enumerated for every node referenced by
    ``cell_grid.cell_node_map``; nodes not referenced by any cell get -1
    (shifted by ``dof_offset``) in ``dofs``.
    '''
    implements(ICellArraySource)

    cell_grid = Instance(CellGrid)

    get_cell_point_X_arr = DelegatesTo('cell_grid')
    get_cell_mvpoints = DelegatesTo('cell_grid')
    cell_node_map = DelegatesTo('cell_grid')
    get_cell_offset = DelegatesTo('cell_grid')

    # offset of dof within domain list
    #
    dof_offset = Int(0)

    # number of degrees of freedom in a single node
    #
    n_nodal_dofs = Int(3)

    #-------------------------------------------------------------------------
    # Generation methods for geometry and index maps
    #-------------------------------------------------------------------------
    n_dofs = Property(depends_on='cell_grid.shape,n_nodal_dofs,dof_offset')

    def _get_n_dofs(self):
        '''
        Get the total number of DOFs: the number of distinct nodes used by
        the cells times the number of DOFs per node.
        '''
        unique_cell_nodes = unique(self.cell_node_map.flatten())
        n_unique_nodes = len(unique_cell_nodes)
        return n_unique_nodes * self.n_nodal_dofs

    dofs = Property(depends_on='cell_grid.shape,n_nodal_dofs,dof_offset')

    @cached_property
    def _get_dofs(self):
        '''
        Construct the node -> DOF number array.

        Returns an (n_nodes, n_nodal_dofs) integer array. Rows of nodes that
        are used by some cell are numbered consecutively; all other rows
        hold -1. Every entry is finally shifted by ``dof_offset``.
        '''
        cell_node_map = self.cell_node_map

        unique_cell_nodes = unique(cell_node_map.flatten())
        n_unique_nodes = len(unique_cell_nodes)

        n_nodal_dofs = self.n_nodal_dofs
        n_nodes = self.cell_grid.point_grid_size
        node_dof_array = repeat(-1, n_nodes * n_nodal_dofs).reshape(
            n_nodes, n_nodal_dofs)

        # Enumerate the DOFs in the mesh. The result is an array with n_nodes
        # rows and n_nodal_dofs columns
        #
        # A = array( [[ 0, 1 ],
        #             [ 2, 3 ],
        #             [ 4, 5 ]] );
        #
        node_dof_array[index_exp[unique_cell_nodes]] = \
            arange(n_unique_nodes * n_nodal_dofs).reshape(n_unique_nodes,
                                                          n_nodal_dofs)

        # add the dof_offset before returning the array
        #
        node_dof_array += self.dof_offset
        return node_dof_array

    def _get_doffed_nodes(self):
        '''
        Get the indices of nodes containing DOFs, i.e. the sorted indices of
        all nodes referenced by the cell_node_map.
        '''
        cell_node_map = self.cell_node_map
        unique_cell_nodes = unique(cell_node_map.flatten())
        n_nodes = self.cell_grid.point_grid_size
        doffed_nodes = repeat(-1, n_nodes)
        doffed_nodes[index_exp[unique_cell_nodes]] = 1
        return where(doffed_nodes > 0)[0]

    #-----------------------------------------------------------------
    # Elementwise-representation of dofs
    #-----------------------------------------------------------------
    # All maps below are derived from `dofs`, which depends on dof_offset,
    # so they must be re-evaluated when the offset changes as well.

    cell_dof_map = Property(
        depends_on='cell_grid.shape,n_nodal_dofs,dof_offset')

    def _get_cell_dof_map(self):
        '''DOF numbers of every cell, indexed through cell_node_map.'''
        return self.dofs[index_exp[self.cell_grid.cell_node_map]]

    cell_grid_dof_map = Property(
        depends_on='cell_grid.shape,n_nodal_dofs,dof_offset')

    def _get_cell_grid_dof_map(self):
        '''DOF numbers indexed through cell_grid.cell_grid_node_map.'''
        return self.dofs[index_exp[self.cell_grid.cell_grid_node_map]]

    def get_cell_dofs(self, cell_idx):
        '''Return the DOF numbers of the cell(s) selected by cell_idx.'''
        return self.cell_dof_map[cell_idx]

    # Cached: without dof_offset in depends_on the cached map would keep the
    # DOF numbers of a previous offset.
    elem_dof_map = Property(
        depends_on='cell_grid.shape,n_nodal_dofs,dof_offset')

    @cached_property
    def _get_elem_dof_map(self):
        '''
        Flatten cell_dof_map to one row per cell:
        (n_cells, n_nodes_per_cell * n_nodal_dofs).
        '''
        el_dof_map = copy(self.cell_dof_map)
        tot_shape = el_dof_map.shape[0]
        n_entries = el_dof_map.shape[1] * el_dof_map.shape[2]
        elem_dof_map = el_dof_map.reshape(tot_shape, n_entries)
        return elem_dof_map

    def __getitem__(self, idx):
        '''High level access and slicing to the cells within the grid.

        Returns a DofGridSlice wrapping this grid and the slice ``idx``.
        '''
        dgs = DofGridSlice(dof_grid=self, grid_slice=idx)
        return dgs

    #-----------------------------------------------------------------
    # Spatial queries for dofs
    #-----------------------------------------------------------------

    def _get_dofs_for_nodes(self, nodes):
        '''Get the dof numbers and associated coordinates given the array of
        nodes.

        Only nodes that carry DOFs (see _get_doffed_nodes) are considered.
        Returns a tuple (dofs, coordinates).
        '''
        doffed_nodes = self._get_doffed_nodes()
        intersect_nodes = intersect1d(nodes, doffed_nodes, assume_unique=False)
        return (self.dofs[index_exp[intersect_nodes]],
                self.cell_grid.point_X_arr[index_exp[intersect_nodes]])

    def get_boundary_dofs(self):
        '''Get the boundary dofs and the associated coordinates.

        The dofs/coordinates of all boundary slices of the cell grid are
        stacked vertically.
        '''
        nodes = [self.cell_grid.point_idx_grid[s]
                 for s in self.cell_grid.boundary_slices]
        dofs, coords = [], []
        for n in nodes:
            d, c = self._get_dofs_for_nodes(n)
            dofs.append(d)
            coords.append(c)
        return (vstack(dofs), vstack(coords))

    def get_all_dofs(self):
        nodes = self.cell_grid.point_idx_grid[...]
        return self._get_dofs_for_nodes(nodes)

    def get_left_dofs(self):
        nodes = self.cell_grid.point_idx_grid[0, ...]
        return self._get_dofs_for_nodes(nodes)

    def get_right_dofs(self):
        nodes = self.cell_grid.point_idx_grid[-1, ...]
        return self._get_dofs_for_nodes(nodes)

    def get_top_dofs(self):
        nodes = self.cell_grid.point_idx_grid[:, -1, ...]
        return self._get_dofs_for_nodes(nodes)

    def get_bottom_dofs(self):
        nodes = self.cell_grid.point_idx_grid[:, 0, ...]
        return self._get_dofs_for_nodes(nodes)

    def get_front_dofs(self):
        nodes = self.cell_grid.point_idx_grid[:, :, -1]
        return self._get_dofs_for_nodes(nodes)

    def get_back_dofs(self):
        nodes = self.cell_grid.point_idx_grid[:, :, 0]
        return self._get_dofs_for_nodes(nodes)

    def get_bottom_left_dofs(self):
        nodes = self.cell_grid.point_idx_grid[0, 0, ...]
        return self._get_dofs_for_nodes(nodes)

    def get_bottom_front_dofs(self):
        nodes = self.cell_grid.point_idx_grid[:, 0, -1]
        return self._get_dofs_for_nodes(nodes)

    def get_bottom_back_dofs(self):
        nodes = self.cell_grid.point_idx_grid[:, 0, 0]
        return self._get_dofs_for_nodes(nodes)

    def get_top_left_dofs(self):
        nodes = self.cell_grid.point_idx_grid[0, -1, ...]
        return self._get_dofs_for_nodes(nodes)

    def get_bottom_right_dofs(self):
        nodes = self.cell_grid.point_idx_grid[-1, 0, ...]
        return self._get_dofs_for_nodes(nodes)

    def get_top_right_dofs(self):
        nodes = self.cell_grid.point_idx_grid[-1, -1, ...]
        return self._get_dofs_for_nodes(nodes)

    # The *_middle_* queries below are only defined for an odd number of
    # nodes along the halved axis. Otherwise they print an error and return
    # None. Floor division (//) keeps the middle index an integer on both
    # Python 2 and Python 3 (plain / yields a float index on Python 3).

    def get_bottom_middle_dofs(self):
        if self.cell_grid.point_idx_grid.shape[0] % 2 == 1:
            slice_middle_x = self.cell_grid.point_idx_grid.shape[0] // 2
            nodes = self.cell_grid.point_idx_grid[slice_middle_x, 0, ...]
            return self._get_dofs_for_nodes(nodes)
        else:
            print('Error in get_bottom_middle_dofs:'
                  ' the method is only defined for an odd number of dofs in x-direction')

    def get_top_middle_dofs(self):
        if self.cell_grid.point_idx_grid.shape[0] % 2 == 1:
            slice_middle_x = self.cell_grid.point_idx_grid.shape[0] // 2
            nodes = self.cell_grid.point_idx_grid[slice_middle_x, -1, ...]
            return self._get_dofs_for_nodes(nodes)
        else:
            print('Error in get_top_middle_dofs:'
                  ' the method is only defined for an odd number of dofs in x-direction')

    def get_left_middle_dofs(self):
        if self.cell_grid.point_idx_grid.shape[1] % 2 == 1:
            slice_middle_y = self.cell_grid.point_idx_grid.shape[1] // 2
            nodes = self.cell_grid.point_idx_grid[0, slice_middle_y, ...]
            return self._get_dofs_for_nodes(nodes)
        else:
            print('Error in get_left_middle_dofs:'
                  ' the method is only defined for an odd number of dofs in y-direction')

    def get_right_middle_dofs(self):
        if self.cell_grid.point_idx_grid.shape[1] % 2 == 1:
            slice_middle_y = self.cell_grid.point_idx_grid.shape[1] // 2
            nodes = self.cell_grid.point_idx_grid[-1, slice_middle_y, ...]
            return self._get_dofs_for_nodes(nodes)
        else:
            print('Error in get_right_middle_dofs:'
                  ' the method is only defined for an odd number of dofs in y-direction')

    def get_left_front_bottom_dof(self):
        nodes = self.cell_grid.point_idx_grid[0, 0, -1]
        return self._get_dofs_for_nodes(nodes)

    def get_left_front_middle_dof(self):
        if self.cell_grid.point_idx_grid.shape[1] % 2 == 1:
            slice_middle_y = self.cell_grid.point_idx_grid.shape[1] // 2
            nodes = self.cell_grid.point_idx_grid[0, slice_middle_y, -1]
            return self._get_dofs_for_nodes(nodes)
        else:
            print('Error in get_left_front_middle_dof:'
                  ' the method is only defined for an odd number of dofs in y-direction')

    #-----------------------------------------------------------------
    # Visualization related methods
    #-----------------------------------------------------------------

    refresh_button = Button('Draw')

    @on_trait_change('refresh_button')
    def redraw(self):
        '''Redraw the point grid.
        '''
        self.cell_grid.redraw()

    dof_cell_array = Button

    def _dof_cell_array_fired(self):
        '''Open a live CellArray window showing the cell_node_map.'''
        cell_array = self.cell_grid.cell_node_map
        self.show_array = CellArray(data=cell_array,
                                    cell_view=DofCellView(cell_grid=self))
        self.show_array.current_row = 0
        self.show_array.configure_traits(kind='live')

    #------------------------------------------------------------------
    # UI - related methods
    #------------------------------------------------------------------
    traits_view = View(Item('n_nodal_dofs'),
                       Item('dof_offset'),
                       Item('cell_grid@', show_label=False),
                       Item('refresh_button', show_label=False),
                       Item('dof_cell_array', show_label=False),
                       resizable=True,
                       scrollable=True,
                       height=0.5,
                       width=0.5)
class ExperimentFactory(Loggable, ConsumerMixin):
    """Create automated runs with ``run_factory`` and add them to ``queue``.

    Add requests go through ConsumerMixin: ``_add_button_fired`` calls
    ``add_consumable`` and the consumer registered in ``__init__`` runs
    ``_add_run``.
    """
    db = Any
    run_factory = Instance(AutomatedRunFactory)
    queue_factory = Instance(ExperimentQueueFactory)

    add_button = Button('Add')
    clear_button = Button('Clear')
    save_button = Button('Save')
    edit_mode_button = Button('Edit')
    edit_enabled = DelegatesTo('run_factory')

    auto_increment_id = Bool(False)
    auto_increment_position = Bool(False)

    queue = Instance(ExperimentQueue, ())

    # NOTE(review): '_extract_device' is not a trait of this class (the trait
    # is 'extract_device'), and _get_ok_add does not read it either.
    ok_add = Property(
        depends_on='_mass_spectrometer, _extract_device, _labnumber, _username'
    )

    _username = String
    _mass_spectrometer = String
    extract_device = String
    _labnumber = String

    selected_positions = List
    default_mass_spectrometer = Str

    def __init__(self, *args, **kw):
        super(ExperimentFactory, self).__init__(*args, **kw)
        # Requests queued with add_consumable() are handled by _add_run.
        self.setup_consumer(self._add_run)

    def destroy(self):
        # Presumably stops the ConsumerMixin consumer set up in __init__.
        self._should_consume = False

    def set_selected_runs(self, runs):
        self.run_factory.set_selected_runs(runs)

    def _add_run(self, *args, **kw):
        """Create new runs via run_factory, add them to the queue and select
        the run at (last selected index + number of new runs).

        Any arguments passed by the consumer are ignored.
        """
        # Highest extract group already in the queue (0 for an empty queue),
        # passed to new_runs as extract_group_cnt.
        groups = {ai.extract_group for ai in self.queue.automated_runs}
        eg = max(groups) if groups else 0

        positions = [str(pi.positions[0]) for pi in self.selected_positions]

        new_runs, freq = self.run_factory.new_runs(
            positions=positions,
            auto_increment_position=self.auto_increment_position,
            auto_increment_id=self.auto_increment_id,
            extract_group_cnt=eg)

        q = self.queue
        # Index of the last selected run, or of the last run when nothing is
        # selected.
        if q.selected:
            idx = q.automated_runs.index(q.selected[-1])
        else:
            idx = len(q.automated_runs) - 1

        q.add_runs(new_runs, freq)

        # Presumably the last of the newly added runs, assuming add_runs
        # inserts after the selection -- TODO confirm.
        idx += len(new_runs)
        with self.run_factory.update_selected_ctx():
            q.select_run_idx(idx)

    #===========================================================================
    # handlers
    #===========================================================================
    def _clear_button_fired(self):
        self.queue.clear_frequency_runs()

    def _add_button_fired(self):
        """Request an add; the consumer set up in __init__ runs _add_run.

        ConsumerMixin.add_consumable is used instead of frequency limiting
        the button.
        """
        self.add_consumable(1)

    def _edit_mode_button_fired(self):
        self.run_factory.edit_mode = not self.run_factory.edit_mode

    def _update_end_after(self, new):
        """Listener on run_factory.end_after (see _run_factory_factory).

        When enabled, clear end_after on every queued run before forwarding
        the value to run_factory.
        """
        if new:
            for ai in self.queue.automated_runs:
                ai.end_after = False

        self.run_factory.set_end_after(new)

    @on_trait_change('''queue_factory:[mass_spectrometer, extract_device, delay_+, tray, username, load_name]''')
    def _update_queue(self, name, new):
        """Propagate queue_factory changes to this factory and to the queue."""
        if name == 'mass_spectrometer':
            self._mass_spectrometer = new
            self.run_factory.set_mass_spectrometer(new)
        elif name == 'extract_device':
            self._set_extract_device(new)
        elif name == 'username':
            self._username = new

        if self.queue:
            # Mirror the changed queue_factory trait onto the queue.
            self.queue.trait_set(**{name: new})
            self.queue.changed = True

    #===============================================================================
    # private
    #===============================================================================
    def _set_extract_device(self, ed):
        """Switch to extract device ``ed``: rebuild run_factory, reload its
        templates and patterns, and update the queue."""
        self.extract_device = ed
        self.run_factory = self._run_factory_factory()
        self.run_factory.load_templates()
        self.run_factory.remote_patterns = self._get_patterns(ed)
        self.run_factory.load_patterns()

        if self.queue:
            self.queue.set_extract_device(ed)
            self.queue.username = self._username
            self.queue.mass_spectrometer = self._mass_spectrometer

    def _get_patterns(self, ed):
        """Return the pattern names of the ILaserManager service for ``ed``,
        or an empty list if that service is not available."""
        ps = []
        service_name = convert_extract_device(ed)
        man = self.application.get_service(ILaserManager,
                                           'name=="{}"'.format(service_name))
        if man:
            ps = man.get_pattern_names()
        else:
            self.debug('No remote patterns. {} ({}) not available'.format(
                ed, service_name))

        return ps

    #===============================================================================
    # property get/set
    #===============================================================================
    def _get_ok_add(self):
        """Truthy when a username, a real mass spectrometer and a labnumber
        are set.

        Returns the last evaluated operand of the ``and`` chain (a string or
        False), not a strict bool.
        """
        # TODO: tol should be a user permission
        return (self._username and
                self._mass_spectrometer not in ('', 'Spectrometer', LINE_STR) and
                self._labnumber)

    #===============================================================================
    # factories
    #===============================================================================
    def _run_factory_factory(self):
        """Build the run factory for the current extract device and wire its
        labnumber / end_after listeners back to this object."""
        if self.extract_device == 'Fusions UV':
            klass = UVAutomatedRunFactory
        else:
            klass = AutomatedRunFactory

        rf = klass(db=self.db,
                   application=self.application,
                   extract_device=self.extract_device,
                   mass_spectrometer=self.default_mass_spectrometer)

        rf.load_truncations()
        rf.on_trait_change(lambda x: self.trait_set(_labnumber=x), 'labnumber')
        rf.on_trait_change(self._update_end_after, 'end_after')
        return rf

    #===============================================================================
    # defaults
    #===============================================================================
    def _run_factory_default(self):
        return self._run_factory_factory()

    def _queue_factory_default(self):
        eq = ExperimentQueueFactory(db=self.db)
        return eq

    def _db_changed(self):
        self.queue_factory.db = self.db
        self.run_factory.db = self.db

    def _default_mass_spectrometer_changed(self):
        self.run_factory.set_mass_spectrometer(self.default_mass_spectrometer)
        self.queue_factory.mass_spectrometer = self.default_mass_spectrometer
        self._mass_spectrometer = self.default_mass_spectrometer
class SetupPane(TraitsTaskPane):
    """A TraitsTaskPane containing the factory selection and new object
    configuration editors."""

    # -------------------
    # Required Attributes
    # -------------------

    #: State object whose selected_view, selected_factory_name and
    #: entity_creator determine what this pane displays.
    system_state = Instance(SystemState, allow_none=False)

    # ------------------
    # Regular Attributes
    # ------------------

    #: An internal identifier for this pane
    id = 'force_wfmanager.setup_pane'

    #: Name displayed as the title of this pane
    name = 'Setup Pane'

    #: Enables or disables the object views displayed
    ui_enabled = Bool(True)

    # ------------------
    # Derived Attributes
    # ------------------

    #: The model from selected_view.
    #: Listens to: :attr:`models.workflow_tree.selected_view
    #: <force_wfmanager.ui.setup.workflow_tree.WorkflowTree.selected_view>`
    selected_model = Instance(BaseModel)

    #: An error message for the entire workflow
    error_message = Str()

    # --------------------
    # Button Attributes
    # --------------------

    #: A Button which calls add_new_entity when pressed.
    add_new_entity_btn = Button()

    #: A Button which calls remove_entity when pressed.
    remove_entity_btn = Button()

    # ----------------
    # Properties
    # ----------------

    #: The string displayed on the 'add new entity' button.
    add_new_entity_label = Property(
        Str, depends_on='system_state:selected_factory_name'
    )

    #: A Boolean indicating whether the currently selected modelview is
    #: intended to be editable by the user. This is required to avoid
    #: displaying a default view when a model does not have a View defined
    #: for it. If a modelview has a View defining how it is represented in
    #: the UI then this is used.
    selected_view_editable = Property(
        Bool, depends_on='system_state:selected_view'
    )

    #: A Boolean indicating whether the currently selected view is the main
    #: view containing the FORCE logo
    main_view_visible = Property(
        Bool, depends_on='system_state:selected_factory_name'
    )

    #: A Boolean indicating whether the currently selected view represents
    #: either a KPI or Parameter.
    mco_view_visible = Property(
        Bool,
        depends_on='system_state:[selected_view,selected_factory_name]'
    )

    #: A Boolean indicating whether the currently selected view represents
    #: a factory view.
    factory_view_visible = Property(
        Bool, depends_on='system_state:selected_factory_name'
    )

    #: A Boolean indicating whether the currently selected view represents
    #: an instance view.
    instance_view_visible = Property(
        Bool, depends_on='system_state:selected_factory_name'
    )

    #: A Boolean indicating whether the entity creator view should be
    #: displayed
    entity_creator_visible = Property(
        Bool, depends_on='system_state:entity_creator'
    )

    #: A panel displaying extra information about the workflow: Available
    #: Plugins, non-KPI variables, current filenames and any error messages.
    #: Displayed for factories which have a lot of empty screen space.
    #: The getter is cached and passes error_message into the WorkflowInfo,
    #: so error_message must be a dependency or the shown message goes stale.
    current_info = Property(
        Instance(WorkflowInfo),
        depends_on='system_state:[selected_factory_name,selected_view],'
                   'task.current_file,error_message'
    )

    #: Determines if the add button should be enabled.
    # MCO and Execution Layers can always be added, but other workflow items
    # need an entity_creator with a model.
    add_button_enabled = Property(
        Bool, depends_on='system_state:[selected_factory_name,'
                         'entity_creator.model]'
    )

    #: Determines if the add button should be visible.
    add_button_visible = Property(
        Bool, depends_on='system_state:selected_factory_name'
    )

    #: Determines if the remove button should be visible.
    remove_button_visible = Property(
        Bool, depends_on='system_state:selected_factory_name'
    )

    # -------------------
    #        View
    # -------------------

    def default_traits_view(self):
        """Set up a TraitsUI view which displays the details (parameters
        etc.) of the currently selected view.

        This varies depending on what type of view is selected."""
        view = View(
            VGroup(
                # Main View with FORCE Logo
                VGroup(
                    UItem(
                        "current_info",
                        editor=InstanceEditor(),
                        style="custom",
                    ),
                    visible_when="main_view_visible",
                    enabled_when="ui_enabled"
                ),
                # MCO Parameter and KPI views
                VGroup(
                    UItem(
                        "object.system_state.selected_view",
                        editor=InstanceEditor(),
                        style="custom",
                    ),
                    visible_when="mco_view_visible",
                    enabled_when="ui_enabled"
                ),
                # Process Tree Views
                HGroup(
                    # Instance View
                    VGroup(
                        UItem(
                            "object.system_state.selected_view",
                            editor=InstanceEditor(),
                            style="custom",
                            visible_when="selected_view_editable"
                        ),
                        UItem(
                            "selected_model",
                            editor=InstanceEditor(),
                            style="custom",
                            visible_when="selected_model is not None"
                        ),
                        # Remove Buttons
                        HGroup(
                            UItem('remove_entity_btn', label='Delete'),
                        ),
                        label="Item Details",
                        visible_when="instance_view_visible",
                        show_border=True,
                    ),
                    # Factory View
                    VGroup(
                        HGroup(
                            UItem(
                                "object.system_state.entity_creator",
                                editor=InstanceEditor(),
                                style="custom",
                                visible_when="entity_creator_visible",
                                width=825
                            ),
                            springy=True,
                        ),
                        HGroup(
                            UItem(
                                'add_new_entity_btn',
                                editor=ButtonEditor(
                                    label_value='add_new_entity_label'
                                ),
                                enabled_when='add_button_enabled',
                                visible_when="add_button_visible",
                                springy=True
                            ),
                            # Remove Buttons
                            UItem(
                                'remove_entity_btn',
                                label='Delete Layer',
                                visible_when='remove_button_visible'
                            ),
                            label="New Item Details",
                            visible_when="factory_view_visible",
                            show_border=True,
                        ),
                    ),
                    enabled_when="ui_enabled",
                ),
            ),
            scrollable=True,
            width=500,
        )

        return view

    # -------------------
    #  Property getters
    # -------------------

    @cached_property
    def _get_selected_view_editable(self):
        """Determine if the selected modelview in the WorkflowTree has a
        non-default view associated with it.

        A default view should not be editable by the user, a non-default one
        should be.

        Returns
        -------
        Bool
            True if selected_view has at least one View listed by
            ``trait_views()`` (the default view is not included in that
            list). False if it has none or no modelview is selected.
        """
        selected = self.system_state.selected_view
        if selected is None:
            return False
        return len(selected.trait_views()) != 0

    @cached_property
    def _get_main_view_visible(self):
        return self.system_state.selected_factory_name == 'Workflow'

    @cached_property
    def _get_mco_view_visible(self):
        if self.system_state.selected_view is not None:
            mco_factories = ['MCO Parameters', 'MCO KPIs']
            return (
                self.system_state.selected_factory_name in mco_factories
            )
        return False

    @cached_property
    def _get_factory_view_visible(self):
        factories = ['None', 'Workflow', 'MCO Parameters', 'MCO KPIs']
        return (
            self.system_state.selected_factory_name not in factories
        )

    @cached_property
    def _get_instance_view_visible(self):
        return self.system_state.selected_factory_name == 'None'

    @cached_property
    def _get_entity_creator_visible(self):
        return self.system_state.entity_creator is not None

    @cached_property
    def _get_add_button_enabled(self):
        """Determine if the add button in the UI should be enabled.

        Always enabled for 'Execution Layer' and 'MCO'; otherwise only when
        an entity_creator with a model exists."""
        simple_factories = ['Execution Layer', 'MCO']
        if self.system_state.selected_factory_name in simple_factories:
            return True
        creator = self.system_state.entity_creator
        return creator is not None and creator.model is not None

    @cached_property
    def _get_add_button_visible(self):
        """Determine if the add button in the UI should be visible."""
        return self.system_state.selected_factory_name != 'None'

    @cached_property
    def _get_remove_button_visible(self):
        """Determine if the 'Delete Layer' (remove) button in the UI should
        be visible: only when 'Data Source' is the selected factory."""
        return self.system_state.selected_factory_name == 'Data Source'

    @cached_property
    def _get_add_new_entity_label(self):
        """Return the label displayed on add_new_entity_btn."""
        return 'Add New {!s}'.format(self.system_state.selected_factory_name)

    @cached_property
    def _get_current_info(self):
        return WorkflowInfo(
            workflow_filename=self.task.current_file,
            plugins=self.task.lookup_plugins(),
            selected_factory_name=self.system_state.selected_factory_name,
            error_message=self.error_message
        )

    # Synchronisation with WorkflowTree
    @on_trait_change('system_state:selected_view.error_message')
    def sync_selected_view(self):
        """Synchronise selected_view with the selected modelview in the tree
        editor. Checks if the model held by the modelview needs to be
        displayed in the UI, and copies its error message."""
        selected = self.system_state.selected_view
        if selected is not None:
            if isinstance(selected.model, BaseModel):
                self.selected_model = selected.model
            else:
                self.selected_model = None
            self.error_message = selected.error_message

    # Button event handlers for creating and deleting workflow items
    def _add_new_entity_btn_fired(self):
        """Calls add_new_entity when add_new_entity_btn is clicked"""
        self.system_state.add_new_entity()

    def _remove_entity_btn_fired(self):
        """Calls remove_entity when remove_entity_btn is clicked"""
        self.system_state.remove_entity()

    # -------------------
    #   Private Methods
    # -------------------

    def _console_ns_default(self):
        """Namespace with the task and, if reachable, the application."""
        namespace = {
            "task": self.task
        }
        try:
            namespace["app"] = self.task.window.application
        except AttributeError:
            namespace["app"] = None

        return namespace
class BaselineView(HasTraits):
    """Table and N/E scatter plot of baseline solutions received over an SBP
    link, with optional CSV logging."""
    python_console_cmds = Dict()

    table = List()

    logging_b = Bool(False)
    directory_name_b = File

    plot = Instance(Plot)
    plot_data = Instance(ArrayPlotData)

    running = Bool(True)
    zoomall = Bool(False)
    position_centered = Bool(False)

    clear_button = SVGButton(label='', tooltip='Clear',
                             filename=os.path.join(determine_path(), 'images',
                                                   'iconic', 'x.svg'),
                             width=16, height=16)
    zoomall_button = SVGButton(label='', tooltip='Zoom All', toggle=True,
                               filename=os.path.join(determine_path(), 'images',
                                                     'iconic', 'fullscreen.svg'),
                               width=16, height=16)
    center_button = SVGButton(label='', tooltip='Center on Baseline',
                              toggle=True,
                              filename=os.path.join(determine_path(), 'images',
                                                    'iconic', 'target.svg'),
                              width=16, height=16)
    paused_button = SVGButton(label='', tooltip='Pause', toggle_tooltip='Run',
                              toggle=True,
                              filename=os.path.join(determine_path(), 'images',
                                                    'iconic', 'pause.svg'),
                              toggle_filename=os.path.join(
                                  determine_path(), 'images', 'iconic',
                                  'play.svg'),
                              width=16, height=16)

    reset_button = Button(label='Reset Filters')
    reset_iar_button = Button(label='Reset IAR')
    init_base_button = Button(label='Init. with known baseline')

    traits_view = View(
        HSplit(
            Item('table', style='readonly',
                 editor=TabularEditor(adapter=SimpleAdapter()),
                 show_label=False, width=0.3),
            VGroup(
                HGroup(
                    Item('paused_button', show_label=False),
                    Item('clear_button', show_label=False),
                    Item('zoomall_button', show_label=False),
                    Item('center_button', show_label=False),
                    Item('reset_button', show_label=False),
                    Item('reset_iar_button', show_label=False),
                    Item('init_base_button', show_label=False),
                ),
                Item(
                    'plot',
                    show_label=False,
                    editor=ComponentEditor(bgcolor=(0.8, 0.8, 0.8)),
                ))))

    def _zoomall_button_fired(self):
        self.zoomall = not self.zoomall

    def _center_button_fired(self):
        self.position_centered = not self.position_centered

    def _paused_button_fired(self):
        self.running = not self.running

    def _reset_button_fired(self):
        self.link(MsgResetFilters(filter=0))

    def _reset_iar_button_fired(self):
        self.link(MsgResetFilters(filter=1))

    def _init_base_button_fired(self):
        self.link(MsgInitBase())

    def _clear_button_fired(self):
        """Forget the plotted history and empty every plotted series."""
        # np.nan (not the np.NAN alias, which NumPy 2.0 removed).
        self.neds[:] = np.nan
        self.fixeds[:] = False
        for key in ('n_fixed', 'e_fixed', 'd_fixed',
                    'n_float', 'e_float', 'd_float',
                    't',
                    'cur_fixed_n', 'cur_fixed_e', 'cur_fixed_d',
                    'cur_float_n', 'cur_float_e', 'cur_float_d'):
            self.plot_data.set_data(key, [])

    def iar_state_callback(self, sbp_msg, **metadata):
        self.num_hyps = sbp_msg.num_hyps
        self.last_hyp_update = time.time()

    def _baseline_callback_ned(self, sbp_msg, **metadata):
        # Updating an ArrayPlotData isn't thread safe (see chaco issue #9), so
        # actually perform the update in the UI thread.
        if self.running:
            GUI.invoke_later(self.baseline_callback, sbp_msg)

    def update_table(self):
        # `table` is a List of (label, value) pairs. Lists have no .items()
        # (the previous call always raised AttributeError), so copy the pairs.
        self._table_list = list(self.table)

    def gps_time_callback(self, sbp_msg, **metadata):
        gps_time = MsgGPSTime(sbp_msg)
        self.week = gps_time.wn
        self.nsec = gps_time.ns

    def mode_string(self, msg):
        """Return 'Fixed RTK', 'Float' or 'None' for ``msg``.

        Side effect: sets self.fixed from bit 0 of msg.flags when msg is
        truthy.
        """
        if msg:
            self.fixed = (msg.flags & 1) == 1
            if self.fixed:
                return 'Fixed RTK'
            else:
                return 'Float'
        return 'None'

    def baseline_callback(self, sbp_msg):
        """Decode a baseline NED message, refresh table and plot, and append
        a CSV row when logging is enabled and the GPS week is known."""
        self.last_btime_update = time.time()
        soln = MsgBaselineNED(sbp_msg)
        self.last_soln = soln
        table = []

        # Scale by 1e-3; the log header and axis titles are in meters.
        soln.n = soln.n * 1e-3
        soln.e = soln.e * 1e-3
        soln.d = soln.d * 1e-3
        dist = np.sqrt(soln.n**2 + soln.e**2 + soln.d**2)

        tow = soln.tow * 1e-3
        if self.nsec is not None:
            tow += self.nsec * 1e-9

        if not self.logging_b and self.log_file is not None:
            # Logging was switched off: close the file rather than just
            # dropping the reference, which leaked the open handle.
            self.log_file.close()
            self.log_file = None

        if self.week is not None:
            t = datetime.datetime(1980, 1, 6) + \
                datetime.timedelta(weeks=self.week) + \
                datetime.timedelta(seconds=tow)
            table.append(('GPS Time', t))
            table.append(('GPS Week', str(self.week)))

            # Each CSV row is stamped with `t`, which only exists once the
            # GPS week is known, so logging lives inside this branch.
            if self.logging_b:
                if self.log_file is None:
                    log_name = time.strftime("baseline_log_%Y%m%d-%H%M%S.csv")
                    if self.directory_name_b == '':
                        filepath = log_name
                    else:
                        filepath = os.path.join(self.directory_name_b,
                                                log_name)
                    self.log_file = open(filepath, 'w')
                    self.log_file.write(
                        'time,north(meters),east(meters),down(meters),distance(meters),num_signals,flags,num_hypothesis\n'
                    )
                self.log_file.write(
                    '%s,%.4f,%.4f,%.4f,%.4f,%d,0x%02x,%d\n' %
                    (str(t), soln.n, soln.e, soln.d, dist, soln.n_sats,
                     soln.flags, self.num_hyps))
                self.log_file.flush()

        table.append(('GPS ToW', tow))
        table.append(('N', soln.n))
        table.append(('E', soln.e))
        table.append(('D', soln.d))
        table.append(('Dist.', dist))
        table.append(('Num. Signals.', soln.n_sats))
        table.append(('Flags', '0x%02x' % soln.flags))
        # mode_string() also updates self.fixed, which is read below.
        table.append(('Mode', self.mode_string(soln)))
        if time.time() - self.last_hyp_update < 10 and self.num_hyps != 1:
            table.append(('IAR Num. Hyps.', self.num_hyps))
        else:
            table.append(('IAR Num. Hyps.', "None"))

        # Rotate array, deleting oldest entries to maintain
        # no more than N in plot
        self.neds[1:] = self.neds[:-1]
        self.fixeds[1:] = self.fixeds[:-1]

        # Insert latest position
        self.neds[0][:] = [soln.n, soln.e, soln.d]
        self.fixeds[0] = self.fixed

        neds_fixed = self.neds[self.fixeds]
        neds_float = self.neds[np.logical_not(self.fixeds)]
        if not all(map(any, np.isnan(neds_fixed))):
            self.plot_data.set_data('n_fixed', neds_fixed.T[0])
            self.plot_data.set_data('e_fixed', neds_fixed.T[1])
            self.plot_data.set_data('d_fixed', neds_fixed.T[2])
        if not all(map(any, np.isnan(neds_float))):
            self.plot_data.set_data('n_float', neds_float.T[0])
            self.plot_data.set_data('e_float', neds_float.T[1])
            self.plot_data.set_data('d_float', neds_float.T[2])

        if self.fixed:
            self.plot_data.set_data('cur_fixed_n', [soln.n])
            self.plot_data.set_data('cur_fixed_e', [soln.e])
            self.plot_data.set_data('cur_fixed_d', [soln.d])
            self.plot_data.set_data('cur_float_n', [])
            self.plot_data.set_data('cur_float_e', [])
            self.plot_data.set_data('cur_float_d', [])
        else:
            self.plot_data.set_data('cur_float_n', [soln.n])
            self.plot_data.set_data('cur_float_e', [soln.e])
            self.plot_data.set_data('cur_float_d', [soln.d])
            self.plot_data.set_data('cur_fixed_n', [])
            self.plot_data.set_data('cur_fixed_e', [])
            self.plot_data.set_data('cur_fixed_d', [])

        self.plot_data.set_data('ref_n', [0.0])
        self.plot_data.set_data('ref_e', [0.0])
        self.plot_data.set_data('ref_d', [0.0])

        if self.position_centered:
            d = (self.plot.index_range.high - self.plot.index_range.low) / 2.
            self.plot.index_range.set_bounds(soln.e - d, soln.e + d)
            d = (self.plot.value_range.high - self.plot.value_range.low) / 2.
            self.plot.value_range.set_bounds(soln.n - d, soln.n + d)

        if self.zoomall:
            plot_square_axes(self.plot, ('e_fixed', 'e_float'),
                             ('n_fixed', 'n_float'))
        self.table = table

    def __init__(self, link, plot_history_max=1000, dirname=''):
        super(BaselineView, self).__init__()
        self.log_file = None
        self.directory_name_b = dirname
        self.num_hyps = 0
        self.last_hyp_update = 0
        self.last_btime_update = 0
        self.last_soln = None
        self.plot_data = ArrayPlotData(n_fixed=[0.0], e_fixed=[0.0],
                                       d_fixed=[0.0], n_float=[0.0],
                                       e_float=[0.0], d_float=[0.0], t=[0.0],
                                       ref_n=[0.0], ref_e=[0.0], ref_d=[0.0],
                                       cur_fixed_e=[], cur_fixed_n=[],
                                       cur_fixed_d=[], cur_float_e=[],
                                       cur_float_n=[], cur_float_d=[])
        self.plot_history_max = plot_history_max

        # Ring buffer of the last plot_history_max (n, e, d) positions, with
        # a parallel fixed/float flag per row.
        self.neds = np.empty((plot_history_max, 3))
        self.neds[:] = np.nan
        self.fixeds = np.zeros(plot_history_max, dtype=bool)

        self.plot = Plot(self.plot_data)
        color_float = (0.5, 0.5, 1.0)
        color_fixed = 'orange'
        pts_float = self.plot.plot(('e_float', 'n_float'), type='scatter',
                                   color=color_float, marker='dot',
                                   line_width=0.0, marker_size=1.0)
        pts_fixed = self.plot.plot(('e_fixed', 'n_fixed'), type='scatter',
                                   color=color_fixed, marker='dot',
                                   line_width=0.0, marker_size=1.0)
        lin = self.plot.plot(('e_fixed', 'n_fixed'), type='line',
                             color=(1, 0.65, 0, 0.1))
        ref = self.plot.plot(('ref_e', 'ref_n'), type='scatter', color='red',
                             marker='plus', marker_size=5, line_width=1.5)
        cur_fixed = self.plot.plot(('cur_fixed_e', 'cur_fixed_n'),
                                   type='scatter', color=color_fixed,
                                   marker='plus', marker_size=5,
                                   line_width=1.5)
        cur_float = self.plot.plot(('cur_float_e', 'cur_float_n'),
                                   type='scatter', color=color_float,
                                   marker='plus', marker_size=5,
                                   line_width=1.5)
        plot_labels = ['Base Position', 'RTK Fixed', 'RTK Float']
        plots_legend = dict(zip(plot_labels, [ref, cur_fixed, cur_float]))
        self.plot.legend.plots = plots_legend
        self.plot.legend.visible = True

        self.plot.index_axis.tick_label_position = 'inside'
        self.plot.index_axis.tick_label_color = 'gray'
        self.plot.index_axis.tick_color = 'gray'
        self.plot.index_axis.title = 'E (meters)'
        self.plot.index_axis.title_spacing = 5
        self.plot.value_axis.tick_label_position = 'inside'
        self.plot.value_axis.tick_label_color = 'gray'
        self.plot.value_axis.tick_color = 'gray'
        self.plot.value_axis.title = 'N (meters)'
        self.plot.value_axis.title_spacing = 5
        self.plot.padding = (25, 25, 25, 25)

        self.plot.tools.append(PanTool(self.plot))
        zt = ZoomTool(self.plot, zoom_factor=1.1, tool_mode="box",
                      always_on=False)
        self.plot.overlays.append(zt)

        self.week = None
        self.nsec = 0

        self.link = link
        self.link.add_callback(self._baseline_callback_ned,
                               SBP_MSG_BASELINE_NED)
        self.link.add_callback(self.iar_state_callback, SBP_MSG_IAR_STATE)
        self.link.add_callback(self.gps_time_callback, SBP_MSG_GPS_TIME)

        self.python_console_cmds = {'baseline': self}
class config(BaseWorkflowConfig):
    """Traits configuration for the functional preprocessing workflow.

    Besides the workflow parameters, it offers two "Check" buttons that
    report whether the per-subject input files resolved from the templates
    exist on disk.
    """
    uuid = traits.Str(desc="UUID")
    desc = traits.Str(desc='Workflow description')
    # Directories
    base_dir = Directory(
        os.path.abspath('.'),
        mandatory=True,
        desc='Base directory of data. (Should be subject-independent)')
    sink_dir = Directory(os.path.abspath('.'),
                         mandatory=True,
                         desc="Location where the BIP will store the results")
    field_dir = Directory(
        desc="Base directory of field-map data (Should be subject-independent) "
        "Set this value to None if you don't want fieldmap distortion correction")
    surf_dir = Directory(mandatory=True, desc="Freesurfer subjects directory")
    # Subjects
    subjects = traits.List(
        traits.Str,
        mandatory=True,
        usedefault=True,
        desc="Subject id's. Note: These MUST match the subject id's in the "
        "Freesurfer directory. For simplicity, the subject id's should "
        "also match with the location of individual functional files.")
    func_template = traits.String('%s/functional.nii.gz')
    run_datagrabber_without_submitting = traits.Bool(
        desc="Run the datagrabber without submitting to the cluster")
    timepoints_to_remove = traits.Int(0, usedefault=True)
    # Fieldmap
    use_fieldmap = Bool(
        False,
        mandatory=False,
        usedefault=True,
        desc='True to include fieldmap distortion correction. Note: field_dir '
        'must be specified')
    magnitude_template = traits.String('%s/magnitude.nii.gz')
    phase_template = traits.String('%s/phase.nii.gz')
    TE_diff = traits.Float(desc='difference in B0 field map TEs')
    sigma = traits.Int(
        2, desc='2D spatial gaussing smoothing stdev (default = 2mm)')
    echospacing = traits.Float(desc="EPI echo spacing")
    # Motion Correction
    do_slicetiming = Bool(True,
                          usedefault=True,
                          desc="Perform slice timing correction")
    SliceOrder = traits.List(traits.Int)
    TR = traits.Float(1.0, mandatory=True, desc="TR of functional")
    motion_correct_node = traits.Enum(
        'nipy',
        'fsl',
        'spm',
        'afni',
        desc="motion correction algorithm to use",
        usedefault=True,
    )
    loops = traits.List([5], traits.Int(5), usedefault=True)
    speedup = traits.List([5], traits.Int(5), usedefault=True)
    # Artifact Detection
    norm_thresh = traits.Float(1,
                               min=0,
                               usedefault=True,
                               desc="norm thresh for art")
    z_thresh = traits.Float(3, min=0, usedefault=True, desc="z thresh for art")
    # Smoothing
    fwhm = traits.List(
        [0, 5],
        traits.Float(),
        mandatory=True,
        usedefault=True,
        desc="Full width at half max. The data will be smoothed at all values "
        "specified in this list.")
    smooth_type = traits.Enum("susan",
                              "isotropic",
                              'freesurfer',
                              usedefault=True,
                              desc="Type of smoothing to use")
    surface_fwhm = traits.Float(
        0.0,
        desc='surface smoothing kernel, if freesurfer is selected',
        usedefault=True)
    # CompCor
    compcor_select = traits.BaseTuple(
        traits.Bool,
        traits.Bool,
        mandatory=True,
        desc="The first value in the list corresponds to applying "
        "t-compcor, and the second value to a-compcor. Note: "
        "both can be true")
    num_noise_components = traits.Int(
        6,
        usedefault=True,
        desc="number of principle components of the noise to use")
    regress_before_PCA = traits.Bool(True)
    # Highpass Filter
    hpcutoff = traits.Float(128., desc="highpass cutoff", usedefault=True)
    # zscore
    do_zscore = Bool(False)
    # Advanced Options
    use_advanced_options = traits.Bool()
    advanced_script = traits.Code()
    debug = traits.Bool(False)
    # Buttons
    check_func_datagrabber = Button("Check")
    check_field_datagrabber = Button("Check")

    @staticmethod
    def _report_path(root, template, subject, error_prefix):
        """Print whether ``root/(template % subject)`` exists.

        Parameters
        ----------
        root : str
            Directory the template is resolved against.
        template : str
            Path template with a single ``%s`` for the subject id.
        subject : str
            Subject id substituted into ``template``.
        error_prefix : str
            Leading word printed when the path is missing.

        Returns
        -------
        bool
            True if the path exists, False otherwise.
        """
        path = os.path.join(root, template % subject)
        if os.path.exists(path):
            print(path, "exists!")
            return True
        print(error_prefix, path, "does NOT exist!")
        return False

    def _check_func_datagrabber_fired(self):
        """Report each subject's functional file under ``base_dir``.

        Stops at the first missing file.
        """
        for subject in self.subjects:
            if not self._report_path(self.base_dir, self.func_template,
                                     subject, "ERROR"):
                break

    def _check_field_datagrabber_fired(self):
        """Report each subject's magnitude and phase files under ``field_dir``.

        Stops at the first missing file. The "exists!" lines now show the
        ``field_dir`` path that was actually checked (previously they
        printed a ``base_dir`` path).
        """
        for subject in self.subjects:
            for template in (self.magnitude_template, self.phase_template):
                if not self._report_path(self.field_dir, template, subject,
                                         "ERROR:"):
                    return
class StartPage(Editor):
    """Pyface Tasks editor that shows the application's opening page."""

    #: Object whose view is shown; when unset the editor itself is viewed.
    model = Instance(HasTraits)

    #: The Traits UI object, available once the view has been built.
    ui = Instance('traitsui.ui.UI')

    #: Name of the editor as displayed to the user.
    name = Str('Start Page')

    #: Task that owns this editor.
    task = Any()

    #: Button that asks the task to open a new data file.
    open_data_file_button = Button(label='Open Data File', style='button')

    def default_traits_view(self):  # pylint: disable=no-self-use
        """ Build the View used when no explicit view is requested.

        The open-file button is centred horizontally and vertically by
        springs on every side.

        Returns
        -------
        default_traits_view : :py:class:`traitsui.view.View`
            The default traits View object for the model
        """
        centred_row = Group(Spring(),
                            Item('open_data_file_button', show_label=False),
                            Spring(),
                            orientation='horizontal')
        return View(
            Group(Spring(), centred_row, Spring(), orientation='vertical'))

    def create(self, parent):
        """ Build the editor's widget as a Traits UI subpanel.

        Parameters
        ----------
        parent : toolkit-specific widget
            Widget that will contain the subpanel
        """
        subpanel = self.edit_traits(kind='subpanel', parent=parent)
        self.ui = subpanel
        # pylint: disable=invalid-name
        self.control = subpanel.control

    def destroy(self):
        """ Drop the widget reference and dispose of the Traits UI, if any. """
        self.control = None
        if self.ui is None:
            return
        self.ui.dispose()
        self.ui = None

    @observe('open_data_file_button', post_init=True)
    def open(self, event):
        """ Forward a click on the open-file button to the owning task.

        Parameters
        ----------
        event : A :py:class:`traits.observation.events.TraitChangeEvent` instance
            Change event emitted by open_data_file_button (unused)
        """
        self.task.open()
class FiducialsPanel(HasPrivateTraits):
    """Set fiducials on an MRI surface"""
    model = Instance(MRIHeadWithFiducialsModel)

    fid_file = DelegatesTo('model')
    fid_fname = DelegatesTo('model')
    lpa = DelegatesTo('model')
    nasion = DelegatesTo('model')
    rpa = DelegatesTo('model')
    can_save = DelegatesTo('model')
    can_save_as = DelegatesTo('model')
    can_reset = DelegatesTo('model')
    fid_ok = DelegatesTo('model')
    locked = DelegatesTo('model', 'lock_fiducials')

    # Which fiducial is currently being edited / receives picked points
    set = Enum('LPA', 'Nasion', 'RPA')
    current_pos = Array(float, (1, 3))  # for editing

    save_as = Button(label='Save As...')
    save = Button(label='Save')
    reset_fid = Button(label="Reset to File")

    headview = Instance(HeadViewController)
    hsp_obj = Instance(SurfaceObject)

    picker = Instance(object)

    # the layout of the dialog created
    view = View(
        VGroup(Item('fid_file', label='Fiducials File'),
               Item('fid_fname', show_label=False, style='readonly'),
               Item('set', style='custom'),
               Item('current_pos', label='Pos'),
               HGroup(Item('save',
                           enabled_when='can_save',
                           tooltip="If a filename is currently "
                           "specified, save to that file, otherwise "
                           "save to the default file name"),
                      Item('save_as', enabled_when='can_save_as'),
                      Item('reset_fid', enabled_when='can_reset'),
                      show_labels=False),
               enabled_when="locked==False"))

    def __init__(self, *args, **kwargs):
        super(FiducialsPanel, self).__init__(*args, **kwargs)
        # 'set' defaults to 'LPA', so start with current_pos mirroring lpa.
        self.sync_trait('lpa', self, 'current_pos', mutual=True)

    def _reset_fid_fired(self):
        self.model.reset = True

    def _save_fired(self):
        self.model.save()

    def _save_as_fired(self):
        """Ask for a target file (forcing a .fif suffix) and save to it."""
        if self.fid_file:
            default_path = self.fid_file
        else:
            default_path = self.model.default_fid_fname

        dlg = FileDialog(action="save as",
                         wildcard=fid_wildcard,
                         default_path=default_path)
        dlg.open()
        if dlg.return_code != OK:
            return

        path = dlg.path
        if not path.endswith('.fif'):
            path = path + '.fif'
            if os.path.exists(path):
                answer = confirm(
                    None, "The file %r already exists. Should it "
                    "be replaced?", "Overwrite File?")
                if answer != YES:
                    return

        self.model.save(path)

    def _on_pick(self, picker):
        """Assign a point picked on the MRI head surface to the active fiducial.

        Picks that miss the MRI surface are logged and otherwise ignored.
        Before, that branch fell through with ``idx``/``idxs``/``pt``
        unbound and raised UnboundLocalError.
        """
        if self.locked:
            return

        self.picker = picker
        n_pos = len(picker.picked_positions)

        if n_pos == 0:
            logger.debug("GUI: picked empty location")
            return

        mri_actor = self.hsp_obj.surf.actor.actor
        if picker.actor is mri_actor:
            idxs = []
            idx = None
            pt = [picker.pick_position]
        elif mri_actor in picker.actors:
            idxs = [i for i in range(n_pos) if picker.actors[i] is mri_actor]
            idx = idxs[-1]
            pt = [picker.picked_positions[idx]]
        else:
            logger.debug("GUI: picked object other than MRI")
            return

        def round_(x):
            return round(x, 3)

        # Build real lists so str() shows coordinates (map() is lazy on Py3)
        poss = [[round_(v) for v in p] for p in picker.picked_positions]
        pos = [round_(v) for v in picker.pick_position]
        msg = ["Pick Event: %i picked_positions:" % n_pos]

        line = str(pos)
        if idx is None:
            line += " <-pick_position"
        msg.append(line)

        for i, p in enumerate(poss):
            line = str(p)
            if i == idx:
                line += " <- MRI mesh"
            elif i in idxs:
                line += " (<- also MRI mesh)"
            msg.append(line)
        logger.debug(os.linesep.join(msg))

        if self.set == 'Nasion':
            self.nasion = pt
        elif self.set == 'LPA':
            self.lpa = pt
        elif self.set == 'RPA':
            self.rpa = pt
        else:
            raise ValueError("set = %r" % self.set)

    @on_trait_change('set')
    def _on_set_change(self, obj, name, old, new):
        """Re-bind current_pos to the newly selected fiducial and turn the view."""
        self.sync_trait(old.lower(),
                        self,
                        'current_pos',
                        mutual=True,
                        remove=True)
        self.sync_trait(new.lower(), self, 'current_pos', mutual=True)
        if new == 'Nasion':
            self.headview.front = True
        elif new == 'LPA':
            self.headview.left = True
        elif new == 'RPA':
            self.headview.right = True
class FieldViewer(HasTraits):
    """Interactive viewer for iso-surfaces of a 3-D scalar field f(x, y, z)."""

    # Ranges of the three axes
    x0, x1 = Float(-5), Float(5)
    y0, y1 = Float(-5), Float(5)
    z0, z1 = Float(-5), Float(5)
    points = Int(50)  # number of sample points along each axis
    autocontour = Bool(True)  # let Mayavi choose the iso-surface values
    v0, v1 = Float(0.0), Float(1.0)  # limits of the contour slider
    contour = Range("v0", "v1", 0.5)  # iso-surface value used when not automatic
    function = Str("x*x*0.5 + y*y + z*z*2.0")  # scalar-field expression
    function_list = [
        "x*x*0.5 + y*y + z*z*2.0", "x*y*0.5 + np.sin(2*x)*y +y*z*2.0", "x*y*z",
        "np.sin((x*x+y*y)/z)"
    ]
    plotbutton = Button(u"描画")
    scene = Instance(MlabSceneModel, ())

    view = View(
        HSplit(
            VGroup(
                "x0",
                "x1",
                "y0",
                "y1",
                "z0",
                "z1",
                Item('points', label=u"点数"),
                Item('autocontour', label=u"自动等值"),
                Item('plotbutton', show_label=False),
            ),
            VGroup(
                Item('scene',
                     editor=SceneEditor(scene_class=MayaviScene),
                     resizable=True,
                     height=300,
                     width=350),
                Item('function',
                     editor=EnumEditor(name='function_list',
                                       evaluate=lambda x: x)),
                Item('contour',
                     editor=RangeEditor(format="%1.2f",
                                        low_name="v0",
                                        high_name="v1")),
                show_labels=False)),
        width=500,
        resizable=True,
        title=u"三维标量场观察器")

    def _plotbutton_fired(self):
        self.plot()

    def plot(self):
        """Sample ``function`` on the grid and draw iso-surfaces plus a cut plane."""
        # Build the 3-D grid; an imaginary step makes mgrid take `points`
        # samples per axis (endpoints included) instead of a step size.
        x, y, z = np.mgrid[self.x0:self.x1:1j * self.points,
                           self.y0:self.y1:1j * self.points,
                           self.z0:self.z1:1j * self.points]

        # Evaluate the scalar field; the expression refers to the local
        # x, y, z arrays and to np.
        # NOTE(review): eval() executes arbitrary code typed into the
        # `function` field - only acceptable for trusted, local users.
        scalars = eval(self.function)
        self.scene.mlab.clf()  # clear the current scene

        # Draw the iso-surfaces
        g = self.scene.mlab.contour3d(x,
                                      y,
                                      z,
                                      scalars,
                                      contours=8,
                                      transparent=True)
        g.contour.auto_contours = self.autocontour
        self.scene.mlab.axes(figure=self.scene.mayavi_scene)  # add axes

        # Add an X-Y cut plane (normal along z) through the centre of the box
        s = self.scene.mlab.pipeline.scalar_cut_plane(g)
        cutpoint = ((self.x0 + self.x1) / 2, (self.y0 + self.y1) / 2,
                    (self.z0 + self.z1) / 2)
        s.implicit_plane.normal = (0, 0, 1)
        s.implicit_plane.origin = cutpoint

        self.g = g
        self.scalars = scalars
        # Value range of the field, used as the contour slider limits
        self.v0 = np.min(scalars)
        self.v1 = np.max(scalars)

        # With automatic contours off, the new pipeline would otherwise ignore
        # the chosen iso-value until the slider is moved. Only apply it when
        # it lies inside the new data range.
        if self.v0 <= self.contour <= self.v1:
            self._contour_changed()

    def _contour_changed(self):
        """Apply the slider value as the single iso-surface (manual mode only)."""
        if hasattr(self, "g"):
            if not self.g.contour.auto_contours:
                self.g.contour.contours = [self.contour]

    def _autocontour_changed(self):
        """Toggle automatic contours; switching off applies the slider value."""
        if hasattr(self, "g"):
            self.g.contour.auto_contours = self.autocontour
            if not self.autocontour:
                self._contour_changed()
class SbpRelayView(HasTraits):
    """
    SBP Relay view- Class allows user to specify port, IP address, and message set
    to relay over UDP and to configure a skylark connection
    """
    running = Bool(False)
    configured = Bool(False)
    broadcasting = Bool(False)
    msg_enum = Enum('Observations', 'All')
    ip_ad = String(DEFAULT_UDP_ADDRESS)
    port = Int(DEFAULT_UDP_PORT)
    information = String(
        'UDP Streaming\n\nBroadcast SBP information received by'
        ' the console to other machines or processes over UDP. With the \'Observations\''
        ' radio button selected, the console will broadcast the necessary information'
        ' for a rover Piksi to achieve an RTK solution.'
        '\n\nThis can be used to stream observations to a remote Piksi through'
        ' aircraft telemetry via ground control software such as MAVProxy or'
        ' Mission Planner.')
    http_information = String(
        'Skylark - Experimental Piksi Networking\n\n'
        "Skylark is Swift Navigation's Internet service for connecting Piksi receivers without the use of a radio. To receive GPS observations from the closest nearby Piksi base station (within 5km), click Connect to Skylark.\n\n"
    )
    start = Button(label='Start', toggle=True, width=32)
    stop = Button(label='Stop', toggle=True, width=32)
    connected_rover = Bool(False)
    connect_rover = Button(label='Connect to Skylark', toggle=True, width=32)
    disconnect_rover = Button(label='Disconnect from Skylark',
                              toggle=True,
                              width=32)
    skylark_url = String()
    base_pragma = String()
    rover_pragma = String()
    base_device_uid = String()
    rover_device_uid = String()
    # NOTE(review): plain class attribute, not used within this class.
    toggle = True

    view = View(
        VGroup(
            spring,
            HGroup(
                VGroup(
                    Item('running',
                         show_label=True,
                         style='readonly',
                         visible_when='running'),
                    Item('msg_enum',
                         label="Messages to broadcast",
                         style='custom',
                         enabled_when='not running'),
                    Item('ip_ad',
                         label='IP Address',
                         enabled_when='not running'),
                    Item('port', label="Port", enabled_when='not running'),
                    HGroup(
                        spring,
                        UItem('start',
                              enabled_when='not running',
                              show_label=False),
                        UItem('stop',
                              enabled_when='running',
                              show_label=False), spring)),
                VGroup(
                    Item('information',
                         label="Notes",
                         height=10,
                         editor=MultilineTextEditor(
                             TextEditor(multi_line=True)),
                         style='readonly',
                         show_label=False,
                         resizable=True,
                         padding=15),
                    spring,
                ),
            ),
            spring,
            HGroup(
                VGroup(
                    HGroup(
                        spring,
                        UItem('connect_rover',
                              enabled_when='not connected_rover',
                              show_label=False),
                        UItem('disconnect_rover',
                              enabled_when='connected_rover',
                              show_label=False), spring),
                    HGroup(
                        Spring(springy=False, width=2),
                        Item('skylark_url',
                             enabled_when='not connected_rover',
                             show_label=True),
                        Spring(springy=False, width=2)),
                    HGroup(spring, Item('base_pragma', label='Base option '),
                           Item('base_device_uid', label='Base device '),
                           spring),
                    HGroup(spring, Item('rover_pragma', label='Rover option'),
                           Item('rover_device_uid', label='Rover device'),
                           spring),
                ),
                VGroup(
                    Item('http_information',
                         label="Notes",
                         height=10,
                         editor=MultilineTextEditor(
                             TextEditor(multi_line=True)),
                         style='readonly',
                         show_label=False,
                         resizable=True,
                         padding=15),
                    spring,
                ),
            ),
            spring))

    def __init__(self,
                 link,
                 device_uid=None,
                 base=DEFAULT_BASE,
                 whitelist=None,
                 rover_pragma='',
                 base_pragma='',
                 rover_uuid='',
                 base_uuid='',
                 connect=False,
                 verbose=False):
        """
        Traits tab with UI for UDP broadcast of SBP.

        Parameters
        ----------
        link : sbp.client.handler.Handler
            Link for SBP transfer to/from Piksi.
        device_uid : str
            Piksi Device UUID (defaults to None). It can also be set later
            via set_route().
        base : str
            HTTP endpoint (initial value of skylark_url)
        whitelist : [int] | None
            Whitelist handed to the Skylark connection config
            (defaults to None)
        rover_pragma, base_pragma : str
            Initial rover/base option strings.
        rover_uuid, base_uuid : str
            Initial rover/base device UIDs. When empty, device_uid is used
            at connect time.
        connect : bool
            Stored as connect_when_uuid_received.
        verbose : bool
            Passed to the Skylark watchdog thread.
        """
        super(SbpRelayView, self).__init__()
        self.link = link
        # Whitelist used for UDP broadcast view
        self.msgs = OBS_MSGS
        # register a callback when the msg_enum trait changes
        self.on_trait_change(self.update_msgs, 'msg_enum')
        # Whitelist used for Skylark broadcasting
        self.whitelist = whitelist
        # Previously always None, which silently discarded the argument
        self.device_uid = device_uid
        self.python_console_cmds = {'update': self}
        self.rover_pragma = rover_pragma
        self.base_pragma = base_pragma
        self.rover_device_uid = rover_uuid
        self.base_device_uid = base_uuid
        self.verbose = verbose
        self.skylark_watchdog_thread = None
        self.skylark_url = base
        self.connect_when_uuid_received = bool(connect)

    def update_msgs(self):
        """Update self.msgs, the message types sent over UDP.

        'Observations' selects OBS_MSGS. 'All' selects [None], which is
        handed to link.add_callback in _start_fired.
        """
        if self.msg_enum == 'Observations':
            self.msgs = OBS_MSGS
        elif self.msg_enum == 'All':
            self.msgs = [None]
        else:
            raise NotImplementedError

    def set_route(self, uuid=None, serial_id=None, channel=CHANNEL_UUID):
        """Sets serial_id hash for HTTP headers.

        Parameters
        ----------
        uuid: str
            real uuid of device
        serial_id : int
            Piksi device ID
        channel : str
            UUID namespace for device UUID
        """
        if uuid:
            device_uid = uuid
        elif serial_id:
            device_uid = str(get_uuid(channel, serial_id % 1000))
        else:
            print(
                "Improper call of set_route, either a serial number or UUID should be passed"
            )
            device_uid = str(get_uuid(channel, 1234))
            print(("Setting UUID to default value of {0}".format(device_uid)))
        self.device_uid = device_uid

    def _prompt_setting_error(self, text):
        """Nonblocking prompt for a device setting error.

        Parameters
        ----------
        text : str
            Helpful error message for the user
        """
        prompt = CallbackPrompt(title="Setting Error", actions=[close_button])
        prompt.text = text
        prompt.run(block=False)

    def _disconnect_rover_fired(self):
        """Handle callback for HTTP rover disconnects.

        This is also registered as the watchdog thread's stopped_callback.
        It always leaves connected_rover False.
        """
        try:
            thread = self.skylark_watchdog_thread
            if isinstance(thread, threading.Thread) and not thread.stopped():
                thread.stop()
            elif thread is None:
                # Previously this dereferenced None and landed in the except
                print("Unable to disconnect: no Skylark watchdog thread "
                      "has been started")
            else:
                print(("Unable to disconnect: Skylark watchdog thread "
                       "initialized at {0} and connected since {1} has "
                       "already been stopped").format(
                           thread.get_init_time(), thread.get_connect_time()))
            self.connected_rover = False
        except Exception:
            self.connected_rover = False
            import traceback
            print((traceback.format_exc()))

    def _connect_rover_fired(self):
        """Handle callback for HTTP rover connections.

        Launches an instance of skylark_watchdog_thread. If no device UID is
        known, it shows an error prompt and does nothing else.
        """
        if not self.device_uid:
            msg = "\nDevice ID not found!\n\nConnection requires a valid Piksi device ID."
            self._prompt_setting_error(msg)
            return
        try:
            _base_device_uid = self.base_device_uid or self.device_uid
            _rover_device_uid = self.rover_device_uid or self.device_uid
            config = SkylarkConsoleConnectConfig(
                self.link, self.device_uid, self.skylark_url, self.whitelist,
                self.rover_pragma, self.base_pragma, _rover_device_uid,
                _base_device_uid)
            self.skylark_watchdog_thread = SkylarkWatchdogThread(
                link=self.link,
                skylark_config=config,
                stopped_callback=self._disconnect_rover_fired,
                verbose=self.verbose)
            self.connected_rover = True
            self.skylark_watchdog_thread.start()
        except Exception:
            thread = self.skylark_watchdog_thread
            # Stop a watchdog that is still running. The old check was
            # inverted and only called stop() on already-stopped threads.
            if isinstance(thread, threading.Thread) and not thread.stopped():
                thread.stop()
            self.connected_rover = False
            import traceback
            print((traceback.format_exc()))

    def _start_fired(self):
        """Handle start udp broadcast button.

        Registers a UdpLogger callback on self.link for the message types in
        self.msgs ([None] when 'All' is selected). The running flag is set
        only once registration succeeded, so a failure leaves Start enabled.
        """
        try:
            self.func = UdpLogger(self.ip_ad, self.port)
            self.link.add_callback(self.func, self.msgs)
            self.running = True
        except Exception:
            import traceback
            print((traceback.format_exc()))

    def _stop_fired(self):
        """Handle the stop udp broadcast button.

        It uses self.func and self.msgs to remove the callback that was
        registered when the start button was pressed.
        """
        try:
            self.link.remove_callback(self.func, self.msgs)
            self.func.__exit__()
            self.func = None
            self.running = False
        except Exception:
            import traceback
            print((traceback.format_exc()))
class InterpolatorView(HasTraits):
    """Traits UI controller showing an interpolated scalar of particle arrays."""

    # The bounds on which to interpolate.
    bounds = Array(cols=3,
                   dtype=float,
                   desc='spatial bounds for the interpolation '
                   '(xmin, xmax, ymin, ymax, zmin, zmax)')

    # The number of points to interpolate onto.
    num_points = Int(100000,
                     enter_set=True,
                     auto_set=False,
                     desc='number of points on which to interpolate')

    # The particle arrays to interpolate from.
    particle_arrays = List

    # The scalar to interpolate.
    scalar = Str('rho', desc='name of the active scalar to view')

    # Sync'd trait with the scalar lut manager.
    show_legend = Bool(False, desc='if the scalar legend is to be displayed')

    # Enable/disable the interpolation
    visible = Bool(False, desc='if the interpolation is to be displayed')

    # A button to use the set bounds.
    set_bounds = Button('Set Bounds')

    # A button to recompute the bounds.
    recompute_bounds = Button('Recompute Bounds')

    # Private traits. ######################################################

    # The interpolator we are a view for.
    interpolator = Instance(Interpolator)

    # The mlab plot for this particle array.
    plot = Instance(PipelineBase)

    scalar_list = List

    scene = Instance(MlabSceneModel)

    source = Instance(PipelineBase)

    # True when particle_arrays changed since the interpolator last saw them
    _arrays_changed = Bool(False)

    # View definition ######################################################
    view = View(
        Item(name='visible'),
        Item(name='scalar', editor=EnumEditor(name='scalar_list')),
        Item(name='num_points'),
        Item(name='bounds'),
        Item(name='set_bounds', show_label=False),
        Item(name='recompute_bounds', show_label=False),
        Item(name='show_legend'),
    )

    # Private protocol ###################################################
    def _change_bounds(self):
        """Apply self.bounds to the interpolator (keeping its shape) and redraw."""
        interp = self.interpolator
        if interp is not None:
            interp.set_domain(self.bounds, self.interpolator.shape)
            self._update_plot()

    def _setup_interpolator(self):
        """Create the interpolator on first use, or refresh its particle arrays."""
        if self.interpolator is None:
            interpolator = Interpolator(self.particle_arrays,
                                        num_points=self.num_points)
            self.bounds = interpolator.bounds
            self.interpolator = interpolator
            # Built from the current arrays, so no refresh is pending. The
            # stale flag used to trigger a redundant update on the next call.
            self._arrays_changed = False
        elif self._arrays_changed:
            self.interpolator.update_particle_arrays(self.particle_arrays)
            self._arrays_changed = False

    # Trait handlers #####################################################
    def _particle_arrays_changed(self, pas):
        """Offer the union of all array property names as selectable scalars."""
        # set().union(*...) handles an empty list and needs no functools.reduce
        all_props = set().union(*[set(x.properties.keys()) for x in pas])
        self.scalar_list = list(all_props)
        self._arrays_changed = True
        self._update_plot()

    def _num_points_changed(self, value):
        """Resample the current bounds with about `value` points and redraw."""
        interp = self.interpolator
        if interp is not None:
            bounds = self.interpolator.bounds
            shape = get_nx_ny_nz(value, bounds)
            interp.set_domain(bounds, shape)
            self._update_plot()

    def _recompute_bounds_fired(self):
        bounds = get_bounding_box(self.particle_arrays)
        self.bounds = bounds
        self._change_bounds()

    def _set_bounds_fired(self):
        self._change_bounds()

    def _bounds_default(self):
        return [0, 1, 0, 1, 0, 1]

    @on_trait_change('scalar, visible')
    def _update_plot(self):
        """Create or refresh the Mayavi pipeline, or hide it when not visible."""
        if self.visible:
            mlab = self.scene.mlab
            self._setup_interpolator()
            interp = self.interpolator
            prop = interp.interpolate(self.scalar)
            if self.source is None:
                src = mlab.pipeline.scalar_field(interp.x, interp.y, interp.z,
                                                 prop)
                self.source = src
            else:
                self.source.mlab_source.reset(x=interp.x,
                                              y=interp.y,
                                              z=interp.z,
                                              scalars=prop)
            src = self.source

            if self.plot is None:
                # 3-D data gets a cut plane; lower dimensions a surface
                if interp.dim == 3:
                    plot = mlab.pipeline.scalar_cut_plane(src)
                else:
                    plot = mlab.pipeline.surface(src)
                self.plot = plot
                scm = plot.module_manager.scalar_lut_manager
                scm.set(show_legend=self.show_legend,
                        use_default_name=False,
                        data_name=self.scalar)
                self.sync_trait('show_legend', scm, mutual=True)
            else:
                self.plot.visible = True
                scm = self.plot.module_manager.scalar_lut_manager
                scm.data_name = self.scalar
        else:
            if self.plot is not None:
                self.plot.visible = False
class TDViz(HasTraits):
    """Interactive Mayavi viewer for FITS spectral-line data cubes.

    Loads a cube, lets the user pick a pixel sub-region and data range, and
    renders it as velocity-coloured iso-surfaces, intensity iso-surfaces or a
    volume rendering. The scene can be saved, spun, or exported as a GIF via
    ImageMagick ``convert``.
    """
    fitsfile = File(filter=[u"*.fits"])
    plotbutton1 = Button(u"Plot")
    plotbutton2 = Button(u"Plot")
    plotbutton3 = Button(u"Plot")
    clearbutton = Button(u"Clear")
    scene = Instance(MlabSceneModel, ())
    rendering = Enum("Surface-Spectrum", "Surface-Intensity",
                     "Volume-Intensity")
    save_the_scene = Button(u"Save")
    save_in_file = Str("test.x3d")
    movie = Button(u"Movie")
    iteration = Int(0)
    quality = Int(8)
    delay = Int(0)
    angle = Int(360)
    spin = Button(u"Spin")
    zscale = Int(1)
    xstart = Int(0)
    xend = Int(1)
    ystart = Int(0)
    yend = Int(1)
    zstart = Int(0)
    zend = Int(1)
    datamin = Float(0.0)
    datamax = Float(1.0)
    opacity = Float(0.4)
    dist = Float(0.0)
    leng = Float(0.0)
    vsp = Float(0.0)
    contfile = File(filter=[u"*.fits"])

    view = View(HSplit(
        VGroup(
            Item("fitsfile",
                 label=u"Select a FITS datacube",
                 show_label=True,
                 editor=FileEditor(dialog_style='open')),
            Item("rendering",
                 tooltip=u"Choose the rendering type",
                 show_label=True),
            Item('plotbutton1',
                 tooltip=u"Plot 3D surfaces, color coded by velocities",
                 visible_when="rendering=='Surface-Spectrum'"),
            Item('plotbutton2',
                 tooltip=u"Plot 3D surfaces, color coded by intensities",
                 visible_when="rendering=='Surface-Intensity'"),
            Item('plotbutton3',
                 tooltip=u"Plot 3D dots, color coded by intensities",
                 visible_when="rendering=='Volume-Intensity'"),
            "clearbutton",
            HGroup(
                Item('xstart',
                     tooltip=u"starting pixel in X axis",
                     show_label=True,
                     springy=True),
                Item('xend',
                     tooltip=u"ending pixel in X axis",
                     show_label=True,
                     springy=True)),
            HGroup(
                Item('ystart',
                     tooltip=u"starting pixel in Y axis",
                     show_label=True,
                     springy=True),
                Item('yend',
                     tooltip=u"ending pixel in Y axis",
                     show_label=True,
                     springy=True)),
            HGroup(
                Item('zstart',
                     tooltip=u"starting pixel in Z axis",
                     show_label=True,
                     springy=True),
                Item('zend',
                     tooltip=u"ending pixel in Z axis",
                     show_label=True,
                     springy=True)),
            HGroup(
                Item('datamax',
                     tooltip=u"Maximum datapoint shown",
                     show_label=True,
                     springy=True),
                Item('datamin',
                     tooltip=u"Minimum datapoint shown",
                     show_label=True,
                     springy=True)),
            HGroup(
                Item('dist',
                     tooltip=u"Put a distance in kpc",
                     show_label=True),
                Item('leng',
                     tooltip=
                     u"Put a non-zero bar length in pc to show the scale bar",
                     show_label=True),
                Item('vsp',
                     tooltip=
                     u"Put a non-zero velocity range in km/s to show the scale bar",
                     show_label=True)),
            HGroup(Item('zscale',
                        tooltip=u"Stretch the datacube in Z axis",
                        show_label=True),
                   Item('opacity',
                        tooltip=u"Opacity of the scene",
                        show_label=True),
                   show_labels=False),
            Item('_'),
            Item("contfile",
                 label=u"Add background contours",
                 tooltip=
                 u"This file must be of the same (first two) dimension as the datacube!!!",
                 show_label=True,
                 editor=FileEditor(dialog_style='open')),
            Item('_'),
            HGroup(Item("spin", tooltip=u"Spin 360 degrees", show_label=False),
                   Item("movie", tooltip="Make a GIF movie",
                        show_label=False)),
            HGroup(
                Item('iteration',
                     tooltip=u"number of iterations, 0 means inf.",
                     show_label=True),
                Item('quality',
                     tooltip=u"quality of plots, 0 is worst, 8 is good.",
                     show_label=True)),
            HGroup(
                Item('delay',
                     tooltip=u"time delay between frames, in millisecond.",
                     show_label=True),
                Item('angle', tooltip=u"angle the cube spins",
                     show_label=True)),
            Item('_'),
            HGroup(
                Item("save_the_scene",
                     tooltip=u"Save current scene in a 3D model file"),
                Item("save_in_file",
                     tooltip=u"3D model file name",
                     show_label=False),
                visible_when=
                "rendering=='Surface-Spectrum' or rendering=='Surface-Intensity'"
            ),
            show_labels=False),
        VGroup(Item(name='scene',
                    editor=SceneEditor(scene_class=MayaviScene),
                    resizable=True,
                    height=600,
                    width=600),
               show_labels=False)),
                resizable=True,
                title=u"TDViz")

    def _fitsfile_changed(self):
        """Load the chosen cube into self.data (axes RA, Dec, velocity).

        NaN/Inf pixels are zeroed. The data range and the x/y/z end pixels
        are reset to cover the whole cube. Cubes that are neither 3-D nor
        4-D are rejected with a message.
        """
        # Context manager closes the file; the arrays computed inside are
        # fresh copies, so they stay valid afterwards.
        with fits.open(self.fitsfile) as img:
            self.hdr = img[0].header
            dat = img[0].data
            naxis = self.hdr['NAXIS']
            # The three axes loaded by fits are: velo, dec, ra.
            # Swap the axes RA<->velo and scale by 1000 (the "mJy unit"
            # note in loaddata suggests Jy -> mJy; TODO confirm).
            if naxis == 4:
                self.data = np.swapaxes(dat[0], 0, 2) * 1000.0
            elif naxis == 3:
                self.data = np.swapaxes(dat, 0, 2) * 1000.0
            else:
                print('Unsupported NAXIS = {}: expected a 3-D or 4-D cube'.
                      format(naxis))
                return
        self.data[~np.isfinite(self.data)] = 0.0
        # ndarray.item() replaces np.asscalar (removed in NumPy 1.23)
        self.datamax = self.data.max().item()
        self.datamin = self.data.min().item()
        self.xend = self.data.shape[0] - 1
        self.yend = self.data.shape[1] - 1
        self.zend = self.data.shape[2] - 1

    def loaddata(self):
        """Cut the selected sub-cube and stretch its velocity axis.

        Sets sregion (the stretched region), xrang/yrang/zrang (display
        sizes) and extent ([ra0, ra1, dec0, dec1, v0, v1] with velocities
        in km/s, assuming cdelt3 is in m/s). Out-of-cube ranges and data
        limits are clamped with a 'Wrong number!' message.
        """
        channel = self.data
        ## Reset the range if it is beyond the cube:
        if self.xstart < 0:
            print('Wrong number!')
            self.xstart = 0
        if self.xend > channel.shape[0] - 1:
            print('Wrong number!')
            self.xend = channel.shape[0] - 1
        if self.ystart < 0:
            print('Wrong number!')
            self.ystart = 0
        if self.yend > channel.shape[1] - 1:
            print('Wrong number!')
            self.yend = channel.shape[1] - 1
        if self.zstart < 0:
            print('Wrong number!')
            self.zstart = 0
        if self.zend > channel.shape[2] - 1:
            print('Wrong number!')
            self.zend = channel.shape[2] - 1
        ## Select a region, use mJy unit
        region = channel[self.xstart:self.xend, self.ystart:self.yend,
                         self.zstart:self.zend]

        ## Stretch the cube in V axis
        from scipy.interpolate import splrep
        from scipy.interpolate import splev
        vol = region.shape
        stretch = self.zscale
        ## Stretch parameter: how many times longer the V axis will be
        sregion = np.empty((vol[0], vol[1], vol[2] * stretch))
        chanindex = np.linspace(0, vol[2] - 1, vol[2])
        chanindex2 = np.linspace(0, vol[2] - 1, vol[2] * stretch)
        # Fill every spectrum. The old loops stopped one short in x and y,
        # which left the last row/column of the np.empty buffer uninitialised.
        for j in range(vol[0]):
            for k in range(vol[1]):
                tck = splrep(chanindex, region[j, k, :], k=1)
                sregion[j, k, :] = splev(chanindex2, tck)
        self.sregion = sregion

        # Reset the max/min values
        smin = self.sregion.min().item()
        smax = self.sregion.max().item()
        if self.datamin < smin:
            print('Wrong number!')
            self.datamin = smin
        if self.datamax > smax:
            print('Wrong number!')
            self.datamax = smax

        self.xrang = abs(self.xstart - self.xend)
        self.yrang = abs(self.ystart - self.yend)
        self.zrang = abs(self.zstart - self.zend) * stretch

        ## Keep a record of the coordinates:
        crval1 = self.hdr['crval1']
        cdelt1 = self.hdr['cdelt1']
        crpix1 = self.hdr['crpix1']
        crval2 = self.hdr['crval2']
        cdelt2 = self.hdr['cdelt2']
        crpix2 = self.hdr['crpix2']
        crval3 = self.hdr['crval3']
        cdelt3 = self.hdr['cdelt3']
        crpix3 = self.hdr['crpix3']

        # FITS pixels are 1-based, hence the "+ 1"
        ra_start = (self.xstart + 1 - crpix1) * cdelt1 + crval1
        ra_end = (self.xend + 1 - crpix1) * cdelt1 + crval1
        dec_start = (self.ystart + 1 - crpix2) * cdelt2 + crval2
        dec_end = (self.yend + 1 - crpix2) * cdelt2 + crval2
        vel_start = (self.zstart + 1 - crpix3) * cdelt3 + crval3
        vel_end = (self.zend + 1 - crpix3) * cdelt3 + crval3
        vel_start /= 1e3
        vel_end /= 1e3

        ## Flip the V axis so that velocity decreases along the array
        if cdelt3 > 0:
            self.sregion = self.sregion[:, :, ::-1]
            vel_start, vel_end = vel_end, vel_start

        self.extent = [
            ra_start, ra_end, dec_start, dec_end, vel_start, vel_end
        ]

    @staticmethod
    def _ra_label(c):
        """Format the RA of SkyCoord `c` as e.g. '5h35m17.3s'."""
        return str(int(c.ra.hms.h)) + 'h' + str(int(
            c.ra.hms.m)) + 'm' + str(round(c.ra.hms.s, 1)) + 's'

    @staticmethod
    def _dec_label(c):
        """Format the Dec of SkyCoord `c` as e.g. '-5d23m28.0s'.

        The sign is taken from the full angle. Before, declinations between
        -1 and 0 degrees lost their minus sign, because int(-0.0) is 0.
        """
        sign = '-' if c.dec.degree < 0 else ''
        return sign + str(abs(int(c.dec.dms.d))) + 'd' + str(
            int(abs(c.dec.dms.m))) + 'm' + str(round(abs(c.dec.dms.s),
                                                     1)) + 's'

    def labels(self):
        '''
        Add 3d text to show the axes.

        Also draws the optional distance and velocity scale bars, the corner
        coordinates, the velocity range, and an outline around self.field.
        '''
        fontsize = max(self.xrang, self.yrang) / 40.
        tcolor = (1, 1, 1)
        mlab.text3d(self.xrang / 2,
                    -10,
                    self.zrang + 10,
                    'R.A.',
                    scale=fontsize,
                    orient_to_camera=True,
                    color=tcolor)
        mlab.text3d(-10,
                    self.yrang / 2,
                    self.zrang + 10,
                    'Decl.',
                    scale=fontsize,
                    orient_to_camera=True,
                    color=tcolor)
        mlab.text3d(-10,
                    -10,
                    self.zrang / 2 - 10,
                    'V (km/s)',
                    scale=fontsize,
                    orient_to_camera=True,
                    color=tcolor)

        # Add scale bars
        if self.leng != 0.0:
            if self.dist == 0.0:
                # Avoid the ZeroDivisionError that used to abort labelling
                print('Set a non-zero distance to draw the length scale bar.')
            else:
                distance = self.dist * 1e3
                length = self.leng
                # Small-angle size in degrees, converted to pixels
                leng_pix = np.round(length / distance / np.pi * 180. /
                                    np.abs(self.hdr['cdelt1']))
                bar_x = [self.xrang - 20 - leng_pix, self.xrang - 20]
                bar_y = [self.yrang - 10, self.yrang - 10]
                bar_z = [0, 0]
                mlab.plot3d(bar_x, bar_y, bar_z, color=tcolor, tube_radius=1.)
                mlab.text3d(self.xrang - 30 - leng_pix,
                            self.yrang - 25,
                            0,
                            '{:.2f} pc'.format(length),
                            scale=fontsize,
                            orient_to_camera=False,
                            color=tcolor)
        if self.vsp != 0.0:
            vspan = self.vsp
            vspan_pix = np.round(vspan / np.abs(self.hdr['cdelt3'] / 1e3))
            bar_x = [self.xrang, self.xrang]
            bar_y = [self.yrang - 10, self.yrang - 10]
            bar_z = np.array([5, 5 + vspan_pix]) * self.zscale
            mlab.plot3d(bar_x, bar_y, bar_z, color=tcolor, tube_radius=1.)
            mlab.text3d(self.xrang,
                        self.yrang - 25,
                        10,
                        '{:.1f} km/s'.format(vspan),
                        scale=fontsize,
                        orient_to_camera=False,
                        color=tcolor,
                        orientation=(0, 90, 0))

        # Label the coordinates of the lower-left and upper-right corners
        corners = ((self.extent[0], self.extent[2], 0, 0),
                   (self.extent[1], self.extent[3], self.xrang, self.yrang))
        for ra0, dec0, xpos, ypos in corners:
            c = SkyCoord(ra=ra0 * u.degree, dec=dec0 * u.degree, frame='icrs')
            mlab.text3d(xpos,
                        -10,
                        self.zrang + 5,
                        self._ra_label(c),
                        scale=fontsize,
                        orient_to_camera=True,
                        color=tcolor)
            mlab.text3d(-40,
                        ypos,
                        self.zrang + 5,
                        self._dec_label(c),
                        scale=fontsize,
                        orient_to_camera=True,
                        color=tcolor)

        # V axis: the smaller velocity sits at z = zrang, the larger at z = 0
        v0 = min(self.extent[4], self.extent[5])
        v1 = max(self.extent[4], self.extent[5])
        mlab.text3d(-10,
                    -10,
                    self.zrang,
                    str(round(v0, 1)),
                    scale=fontsize,
                    orient_to_camera=True,
                    color=tcolor)
        mlab.text3d(-10,
                    -10,
                    0,
                    str(round(v1, 1)),
                    scale=fontsize,
                    orient_to_camera=True,
                    color=tcolor)

        mlab.axes(self.field,
                  ranges=self.extent,
                  x_axis_visibility=False,
                  y_axis_visibility=False,
                  z_axis_visibility=False)
        mlab.outline()

    def _plotbutton1_fired(self):
        """Draw intensity iso-surfaces coloured by velocity."""
        mlab.clf()
        self.loaddata()
        # Clip in place to [datamin, datamax]
        np.clip(self.sregion, self.datamin, self.datamax, out=self.sregion)

        # Adapted from: http://docs.enthought.com/mayavi/mayavi/auto/example_atomic_orbital.html#example-atomic-orbital
        field = mlab.pipeline.scalar_field(
            self.sregion)  # Generate a scalar field
        # Secondary 'color' array: one velocity per channel. Use a separate
        # array; the old code aliased self.sregion, overwrote the
        # intensities, and skipped the last channel.
        nchan = self.sregion.shape[2]
        colored = np.empty_like(self.sregion)
        colored[...] = self.extent[4] - np.arange(nchan) * abs(
            self.hdr['cdelt3'])

        field.image_data.point_data.add_array(colored.T.ravel())
        field.image_data.point_data.get_array(1).name = 'color'
        field.update()

        field2 = mlab.pipeline.set_active_attribute(field,
                                                    point_scalars='scalar')
        contour = mlab.pipeline.contour(field2)
        contour2 = mlab.pipeline.set_active_attribute(contour,
                                                      point_scalars='color')

        mlab.pipeline.surface(contour2, colormap='jet', opacity=self.opacity)

        ## Insert a continuum plot
        if self.contfile != '':
            with fits.open(self.contfile) as im:
                channel = im[0].data[0]
                region = np.swapaxes(
                    channel[self.ystart:self.yend, self.xstart:self.xend] *
                    1000., 0, 1)
            field = mlab.contour3d(region, colormap='gist_ncar')
            field.contour.minimum_contour = 5

        self.field = field2
        self.field.scene.render()
        self.labels()
        mlab.view(azimuth=0, elevation=0, distance='auto')
        mlab.show()

    def _plotbutton2_fired(self):
        """Draw intensity iso-surfaces between datamin and datamax."""
        mlab.clf()
        self.loaddata()
        field = mlab.contour3d(self.sregion)  # Generate a scalar field
        field.contour.maximum_contour = self.datamax
        field.contour.minimum_contour = self.datamin
        field.actor.property.opacity = self.opacity

        self.field = field
        self.labels()
        mlab.view(azimuth=0, elevation=0, distance='auto')
        mlab.show()

    def _plotbutton3_fired(self):
        """Volume-render the region between datamin and datamax."""
        mlab.clf()
        self.loaddata()
        field = mlab.pipeline.scalar_field(
            self.sregion)  # Generate a scalar field
        mlab.pipeline.volume(field, vmax=self.datamax, vmin=self.datamin)

        self.field = field
        self.labels()
        mlab.view(azimuth=0, elevation=0, distance='auto')
        mlab.colorbar()
        mlab.show()

    def _save_the_scene_fired(self):
        mlab.savefig(self.save_in_file)

    def _movie_fired(self):
        """Render a spin of `angle` degrees to JPEGs and assemble a GIF.

        Frames go to ./tenpfigz/screenshotNNN.jpg and the movie to
        ./tenpfigz/animation.gif. ImageMagick ``convert`` must be on PATH.
        """
        import subprocess

        framedir = "./tenpfigz"
        if os.path.exists(framedir):
            print("The chance of you using this name is really small...")
        else:
            os.makedirs(framedir)

        # Remove frames from a previous run. The old "filter(...) != []"
        # test was always true on Python 3.
        for old in glob.glob(os.path.join(framedir, "*.jpg")):
            if os.path.isfile(old):
                os.remove(old)

        def frame_path(index):
            # Zero-padded so the sorted glob order equals the frame order;
            # frames >= 100 used to be silently dropped.
            return os.path.join(framedir, 'screenshot{:03d}.jpg'.format(index))

        ## Quality of the movie: 0 is the worst, 8 is ok.
        self.field.scene.anti_aliasing_frames = self.quality
        self.field.scene.disable_render = True
        i = 0
        mlab.savefig(frame_path(i))
        while i < (self.angle / 5):
            self.field.scene.camera.azimuth(5)
            self.field.scene.render()
            i += 1
            mlab.savefig(frame_path(i))
        self.field.scene.disable_render = False

        frames = sorted(glob.glob(os.path.join(framedir, "*.jpg")))
        cmd = ["convert", "-delay", str(self.delay), "-loop",
               str(self.iteration)] + frames + [
                   os.path.join(framedir, "animation.gif")
               ]
        try:
            subprocess.call(cmd)
        except OSError as err:
            print("Could not run ImageMagick 'convert': {}".format(err))

    def _spin_fired(self):
        """Spin the scene 360 degrees (72 steps of 5 degrees) as an animation."""
        self.field.scene.disable_render = True

        @mlab.animate
        def anim():
            # The old `while i < 72` never advanced i and spun forever
            for _ in range(72):
                self.field.scene.camera.azimuth(5)
                self.field.scene.render()
                yield

        # Hold a reference so the running animator is not garbage-collected
        self._spin_animator = anim()
        self.field.scene.disable_render = False

    def _clearbutton_fired(self):
        mlab.clf()
class ParticleArrayHelper(HasTraits):
    """
    This class manages a particle array and sets up the necessary
    plotting related information for it.
    """

    # The particle array we manage.
    particle_array = Instance(ParticleArray)

    # The name of the particle array.
    name = Str

    # Current time.
    time = Float(0.0)

    # The active scalar to view.
    scalar = Str('rho', desc='name of the active scalar to view')

    # The mlab scalar plot for this particle array.
    plot = Instance(PipelineBase)

    # The mlab vectors plot for this particle array.
    plot_vectors = Instance(PipelineBase)

    # List of available scalars in the particle array.
    scalar_list = List(Str)

    scene = Instance(MlabSceneModel)

    # Sync'd trait with the scalar lut manager.
    show_legend = Bool(False, desc='if the scalar legend is to be displayed')

    # Show all scalars.
    list_all_scalars = Bool(False, desc='if all scalars should be listed')

    # Sync'd trait with the dataset to turn on/off visibility.
    visible = Bool(True, desc='if the particle array is to be displayed')

    # Show the time of the simulation on screen.
    show_time = Bool(False, desc='if the current time is displayed')

    # Edit the scalars.
    edit_scalars = Button('More options ...')

    # Show vectors.
    show_vectors = Bool(False, desc='if vectors should be displayed')

    vectors = Str('u, v, w', enter_set=True, auto_set=False,
                  desc='the vectors to display')

    mask_on_ratio = Int(3, desc='mask one in specified points')

    scale_factor = Float(1.0, desc='scale factor for vectors',
                         enter_set=True, auto_set=False)

    edit_vectors = Button('More options ...')

    # Private attribute to store the Text module.
    _text = Instance(PipelineBase)

    # Extra scalars to show.  These will be added and saved to the data if
    # needed.
    extra_scalars = List(Str)

    # Set to True when the particle array is updated with a new property say.
    updated = Event

    # Private attribute to store old value of visibility in case of empty
    # arrays.
    _old_visible = Bool(True)

    ########################################
    # View related code.
    view = View(
        Item(name='name', show_label=False, editor=TitleEditor()),
        Group(
            Group(
                Group(
                    Item(name='visible'),
                    Item(name='show_legend'),
                    Item(name='scalar', editor=EnumEditor(name='scalar_list')),
                    Item(name='list_all_scalars'),
                    Item(name='show_time'),
                    columns=2,
                ),
                Item(name='edit_scalars', show_label=False),
                label='Scalars',
            ),
            Group(
                Item(name='show_vectors'),
                Item(name='vectors'),
                Item(name='mask_on_ratio'),
                Item(name='scale_factor'),
                Item(name='edit_vectors', show_label=False),
                label='Vectors',
            ),
            layout='tabbed'
        )
    )

    # Private protocol ############################################
    def _add_vmag(self, pa):
        """Add a 'vmag' (velocity magnitude) property to ``pa`` if missing.

        Uses the existing 'vmag2' output array when present, otherwise
        computes sqrt(u**2 + v**2 + w**2).  Fires ``updated`` when a new
        property is added.
        """
        if 'vmag' not in pa.properties:
            if 'vmag2' in pa.output_property_arrays:
                vmag = numpy.sqrt(pa.vmag2)
            else:
                vmag = numpy.sqrt(pa.u**2 + pa.v**2 + pa.w**2)
            pa.add_property(name='vmag', data=vmag)
            if len(pa.output_property_arrays) > 0:
                # We do not call add_output_arrays when the default is empty
                # as if it is empty, all arrays are saved anyway. However,
                # adding just vmag in this case will mean that when the
                # particle array is saved it will only save vmag!  This is
                # not what we want, hence we add vmag *only* if the
                # output_property_arrays is non-zero length.
                pa.add_output_arrays(['vmag'])
            self.updated = True

    def _get_scalar(self, pa, scalar):
        """Return the requested scalar from the given particle array.

        Names listed in ``extra_scalars`` are first materialized by calling
        the matching ``_add_<name>`` method.
        """
        if scalar in self.extra_scalars:
            method_name = '_add_' + scalar
            method = getattr(self, method_name)
            method(pa)
        return getattr(pa, scalar)

    def _get_vectors_for_plot(self, vectors):
        """Return the (u, v, w) arrays named in the comma-separated
        ``vectors`` string, or None if it does not name exactly three
        attributes of the particle array."""
        pa = self.particle_array
        comps = [x.strip() for x in vectors.split(',')]
        if len(comps) != 3:
            return None
        try:
            return tuple(getattr(pa, x) for x in comps)
        except AttributeError:
            return None

    # Traits handlers #############################################
    def _edit_scalars_fired(self):
        # The plot only exists once a non-empty array has been shown.
        if self.plot is not None:
            self.plot.edit_traits()

    def _edit_vectors_fired(self):
        # The vectors plot only exists once show_vectors has been enabled.
        if self.plot_vectors is not None:
            self.plot_vectors.edit_traits()

    def _particle_array_changed(self, old, pa):
        """Refresh name, scalar list, plots and time label for a new array.

        ``old`` is the previously managed array (None on first assignment);
        it is used to detect transitions to and from an empty array.
        """
        self.name = pa.name
        self._list_all_scalars_changed(self.list_all_scalars)

        # Update the plot.
        x, y, z = pa.x, pa.y, pa.z
        s = self._get_scalar(pa, self.scalar)
        p = self.plot
        mlab = self.scene.mlab
        empty = len(x) == 0
        old_empty = len(old.x) == 0 if old is not None else True
        if p is None and not empty:
            src = mlab.pipeline.scalar_scatter(x, y, z, s)
            p = mlab.pipeline.glyph(src, mode='point', scale_mode='none')
            p.actor.property.point_size = 3
            scm = p.module_manager.scalar_lut_manager
            scm.set(show_legend=self.show_legend,
                    use_default_name=False,
                    data_name=self.scalar)
            self.sync_trait('visible', p, mutual=True)
            self.sync_trait('show_legend', scm, mutual=True)
            self.plot = p
        elif not empty:
            if len(x) == len(p.mlab_source.x):
                # Same number of points: update the data in place.
                p.mlab_source.set(x=x, y=y, z=z, scalars=s)
                if self.plot_vectors:
                    self._vectors_changed(self.vectors)
            else:
                # Point count changed: the source must be reset.
                vec = None
                if self.plot_vectors:
                    vec = self._get_vectors_for_plot(self.vectors)
                if vec is not None:
                    u, v, w = vec
                    p.mlab_source.reset(x=x, y=y, z=z, scalars=s,
                                        u=u, v=v, w=w)
                else:
                    # Either no vectors are shown or the vectors string is
                    # invalid; unpacking None here used to raise TypeError.
                    p.mlab_source.reset(x=x, y=y, z=z, scalars=s)
                p.mlab_source.update()

        if empty and not old_empty:
            # Hide the source while the array is empty, remembering the
            # visibility so it can be restored later.
            if p is not None:
                src = p.parent.parent
                self._old_visible = src.visible
                src.visible = False
        if old_empty and not empty:
            if p is not None:
                p.parent.parent.visible = self._old_visible
                self._show_vectors_changed(self.show_vectors)

        # Setup the time.
        self._show_time_changed(self.show_time)

    def _scalar_changed(self, value):
        p = self.plot
        if p is not None:
            p.mlab_source.scalars = self._get_scalar(
                self.particle_array, value
            )
            p.module_manager.scalar_lut_manager.data_name = value

    def _list_all_scalars_changed(self, list_all_scalars):
        """Rebuild ``scalar_list`` from the particle array's properties.

        All properties are listed when ``list_all_scalars`` is set or when
        the array declares no output arrays; otherwise only the output
        arrays are listed.  ``extra_scalars`` are always included.
        """
        pa = self.particle_array
        if pa is None:
            # The trait can change before any array has been assigned.
            return
        if list_all_scalars or len(pa.output_property_arrays) == 0:
            # dict keys views cannot be concatenated with a list on
            # Python 3, so materialize them first.
            names = list(pa.properties.keys())
        else:
            names = list(pa.output_property_arrays)
        self.scalar_list = sorted(set(names) | set(self.extra_scalars))

    def _show_time_changed(self, value):
        txt = self._text
        mlab = self.scene.mlab
        if value:
            if txt is not None:
                txt.visible = True
            elif self.plot is not None:
                mlab.get_engine().current_object = self.plot
                txt = mlab.text(0.01, 0.01, 'Time = 0.0', width=0.35)
                self._text = txt
                self._time_changed(self.time)
        else:
            if txt is not None:
                txt.visible = False

    def _vectors_changed(self, value):
        vec = self._get_vectors_for_plot(value)
        # The vectors string can be edited before any plot exists.
        if vec is not None and self.plot is not None:
            self.plot.mlab_source.set(
                vectors=numpy.c_[vec[0], vec[1], vec[2]]
            )

    def _show_vectors_changed(self, value):
        pv = self.plot_vectors
        if pv is not None:
            pv.visible = value
        elif self.plot is not None and value:
            self._vectors_changed(self.vectors)
            pv = self.scene.mlab.pipeline.vectors(
                self.plot.mlab_source.m_data,
                mask_points=self.mask_on_ratio,
                scale_factor=self.scale_factor
            )
            self.plot_vectors = pv

    def _mask_on_ratio_changed(self, value):
        pv = self.plot_vectors
        if pv is not None:
            pv.glyph.mask_points.on_ratio = value

    def _scale_factor_changed(self, value):
        pv = self.plot_vectors
        if pv is not None:
            pv.glyph.glyph.scale_factor = value

    def _time_changed(self, value):
        txt = self._text
        if txt is not None:
            txt.text = 'Time = %.3e' % (value)

    def _extra_scalars_default(self):
        return ['vmag']
class CSVGrapher(Loggable):
    """Load a delimited text file and plot selected columns against each other."""

    open_button = Button('Open')
    plot_button = Button('Plot')
    as_series = Bool(False)

    _path = Str
    path = File

    data_selectors = List
    column_names = List
    stats = Str

    file_name = Property(depends_on='_path')
    short_name = Property(depends_on='_path')
    _graph_count = 0
    delimiter = Str(',')

    def quick_graph(self, p, dets=None, kind='scatter',
                    out_dir='/Users/argonlab2/Sandbox/baselines/auto-down50'):
        """Build, show and save a baseline graph per detector for CSV ``p``.

        :param p: path of a comma-delimited file with a header row.
        :param dets: detector column names; defaults to
            ['H2', 'H1', 'AX', 'L1', 'L2'].
        :param kind: series type for the raw data.
        :param out_dir: directory the PDFs are written to.
        """
        if dets is None:
            dets = ['H2', 'H1', 'AX', 'L1', 'L2']
        for det in dets:
            g = self._gc(p, det, kind)
            g.edit_traits()
            g.save_pdf(os.path.join(out_dir, '{}_obama{}'.format(kind, det)))

    def _gc(self, p, det, kind):
        """Return a Graph of column ``det`` against the first column of the
        first data group in ``p``, plus a smoothed copy, decimated by 50."""
        g = Graph(container_dict=dict(padding=5),
                  window_width=1000,
                  window_height=800,
                  window_x=40, window_y=20)
        # newline='' is what the csv module expects on Python 3.
        with open(p, 'r', newline='') as rfile:
            reader = csv.reader(rfile)
            header = next(reader)
            groups = self._parse_data(reader)

        # groups = [data, ...]; each data is column-major (data[i] = column i).
        data = groups[0]
        x = data[0]
        y = data[header.index(det)]

        sy = smooth(y, window_len=120)
        x = x[::50]
        y = y[::50]
        sy = sy[::50]

        g.new_plot(zoom=True, xtitle='Time (s)',
                   ytitle='{} Baseline Intensity (fA)'.format(det))
        g.new_series(x, y, type=kind, marker='dot', marker_size=2)
        g.new_series(x, sy, line_width=2)
        return g

    def add_column_selector(self):
        """Append a selector cloned from the last one, pointing at the next
        unused column; warn when no columns remain."""
        csnames = self.column_names
        vind = len(self.data_selectors) + 1
        try:
            cs = self.data_selectors[-1].clone_traits(['fit', 'plot_type', 'parent'])
            cs.trait_set(index=csnames[0], value=csnames[vind], column_names=csnames)
            self.data_selectors.append(cs)
        except IndexError:
            self.warning_dialog('No More Columns')

    def remove_column_selector(self, cs):
        self.data_selectors.remove(cs)

    # ===============================================================================
    # handlers
    # ===============================================================================
    def _open_button_fired(self):
        """Ask for a file, read its header and create the first selector."""
        self.data_selectors = []
        dlg = FileDialog(action='open', default_directory=paths.data_dir)
        if dlg.open() == OK:
            self._path = dlg.path

        if not self._path:
            # Dialog cancelled and nothing opened before: nothing to read.
            return

        with open(self._path, 'r', newline='') as rfile:
            reader = csv.reader(rfile, delimiter=self.delimiter)
            # An empty file yields [], which falls through to the warning.
            self.column_names = names = next(reader, [])

        try:
            cs = DataSelector(column_names=names,
                              index=names[0],
                              value=names[1],
                              removable=False,
                              parent=self)
            self.data_selectors.append(cs)
        except IndexError:
            self.warning_dialog('Invalid delimiter {} for {}'.format(
                DELIMITERS[self.delimiter], os.path.basename(self._path)))

    def _parse_data(self, reader):
        """Split the remaining CSV rows into blank-line separated groups.

        :param reader: a csv reader positioned after the header row.
        :returns: list of float arrays, one per non-empty group, each
            transposed so that ``data[i]`` is column ``i``.
        :raises ValueError: if a cell cannot be converted to float.
        """
        groups = []
        rows = []
        for row in reader:
            row = [ri.strip() for ri in row]
            if any(row):
                rows.append(row)
            elif rows:
                # A blank row terminates the current group.
                groups.append(self._rows_to_array(rows))
                rows = []
        if rows:
            # The final group is not followed by a blank row.
            groups.append(self._rows_to_array(rows))
        return groups

    @staticmethod
    def _rows_to_array(rows):
        """Convert rows of numeric strings to a column-major float array."""
        return np.array([[float(v) for v in row] for row in rows]).transpose()

    def _plot_button_fired(self):
        """Parse the current file and open one graph window per data group."""
        with open(self._path, 'r', newline='') as rfile:
            reader = csv.reader(rfile, delimiter=self.delimiter)
            next(reader, None)  # skip the header row
            groups = self._parse_data(reader)

        for data in groups:
            self._show_plot(data)

    def _show_plot(self, data):
        """Plot every data selector's (index, value) columns from ``data``.

        ``data`` is column-major as returned by ``_parse_data``.  Selectors
        with a fit produce a StatsGraph wrapper instead of a bare graph.
        """
        cd = dict(padding=5, stack_order='top_to_bottom')
        csnames = self.column_names
        # np.Inf was removed in NumPy 2.0; np.inf works everywhere.
        xmin = np.inf
        xmax = -np.inf

        if self.as_series:
            g = RegressionGraph(container_dict=cd)
            p = g.new_plot(padding=[50, 5, 5, 50], xtitle='')
            p.value_range.tight_bounds = False
            p.value_range.margin = 0.1
        else:
            g = StackedRegressionGraph(container_dict=cd)

        regressable = False
        for i, csi in enumerate(self.data_selectors):
            if self.as_series:
                plotid = 0
            else:
                p = g.new_plot(padding=[50, 5, 5, 50])
                p.value_range.tight_bounds = False
                p.value_range.margin = 0.1
                plotid = i

            try:
                x = data[csnames.index(csi.index)]
                y = data[csnames.index(csi.value)]
                xmin = min(xmin, min(x))
                xmax = max(xmax, max(x))
                fit = csi.fit if csi.fit != NULL_STR else None
                g.new_series(x, y, fit=fit,
                             filter_outliers=csi.use_filter,
                             type=csi.plot_type,
                             plotid=plotid)
                g.set_x_title(csi.index, plotid=plotid)
                g.set_y_title(csi.value, plotid=plotid)
                if fit:
                    regressable = True
            except IndexError:
                # e.g. the group has fewer columns than the selector expects.
                pass

        if xmin <= xmax:
            # Without any plotted series the limits would be (inf, -inf).
            g.set_x_limits(xmin, xmax, pad='0.1')

        self._graph_count += 1
        gii = StatsGraph(graph=g) if regressable else g
        g._update_graph()

        gii.window_title = '{} Graph {}'.format(self.short_name, self._graph_count)
        gii.window_x = self._graph_count * 20 + 400
        gii.window_y = self._graph_count * 20 + 20
        gii.edit_traits()

    # ===============================================================================
    # property get/set
    # ===============================================================================
    def _get_file_name(self):
        if os.path.isfile(self._path):
            return os.path.relpath(self._path, paths.data_dir)
        else:
            return ''

    def _get_short_name(self):
        if os.path.isfile(self._path):
            return os.path.basename(self._path)
        else:
            return ''

    # ===============================================================================
    # views
    # ===============================================================================
    def traits_view(self):
        v = View(Item('as_series'),
                 Item('delimiter', editor=EnumEditor(values=DELIMITERS)),
                 HGroup(Item('open_button', show_label=False),
                        Item('plot_button', enabled_when='_path', show_label=False),
                        Item('file_name', show_label=False, style='readonly')),
                 Item('data_selectors',
                      show_label=False,
                      editor=ListEditor(mutable=False,
                                        style='custom',
                                        editor=InstanceEditor())),
                 resizable=True,
                 width=525,
                 height=225,
                 title='CSV Plotter')
        return v
class MayaviViewer(HasTraits):
    """
    This class represents a Mayavi based viewer for the particles.  They
    are queried from a running solver.
    """

    particle_arrays = List(Instance(ParticleArrayHelper), [])
    pa_names = List(Str, [])

    interpolator = Instance(InterpolatorView)

    # The default scalar to load up when running the viewer.
    scalar = Str("rho")

    scene = Instance(MlabSceneModel, ())

    ########################################
    # Traits to pull data from a live solver.
    live_mode = Bool(False, desc='if data is obtained from a running solver '
                                 'or from saved files')

    shell = Button('Launch Python Shell')
    host = Str('localhost', desc='machine to connect to')
    port = Int(8800, desc='port to use to connect to solver')
    authkey = Password('pysph', desc='authorization key')
    host_changed = Bool(True)
    client = Instance(MultiprocessingClient)
    controller = Property(depends_on='live_mode, host_changed')

    ########################################
    # Traits to view saved solver output.
    files = List(Str, [])
    directory = Directory()
    current_file = Str('', desc='the file being viewed currently')
    update_files = Button('Refresh')
    file_count = Range(low='_low', high='_n_files', value=0,
                       desc='the file counter')
    play = Bool(False, desc='if all files are played automatically')
    play_delay = Float(0.2, desc='the delay between loading files')
    loop = Bool(False, desc='if the animation is looped')
    # This is len(files) - 1.
    _n_files = Int(0)
    _low = Int(0)

    ########################################
    # Timer traits.
    timer = Instance(Timer)
    interval = Range(0.5, 20.0, 2.0,
                     desc='frequency in seconds with which plot is updated')

    ########################################
    # Solver info/control.
    current_time = Float(0.0, desc='the current time in the simulation')
    time_step = Float(0.0, desc='the time-step of the solver')
    iteration = Int(0, desc='the current iteration number')
    pause_solver = Bool(False, desc='if the solver should be paused')

    ########################################
    # Movie.
    record = Bool(False, desc='if PNG files are to be saved for animation')
    frame_interval = Range(1, 100, 5, desc='the interval between screenshots')
    movie_directory = Str
    # internal counters.
    _count = Int(0)
    _frame_count = Int(0)
    _last_time = Float
    _solver_data = Any
    _file_name = Str
    _particle_array_updated = Bool

    ########################################
    # The layout of the dialog created
    view = View(
        HSplit(
            Group(
                Group(
                    Group(
                        Item(name='directory'),
                        Item(name='current_file'),
                        Item(name='file_count'),
                        HGroup(
                            Item(name='play'),
                            Item(name='play_delay', label='Delay',
                                 resizable=True),
                            Item(name='loop'),
                            Item(name='update_files', show_label=False),
                            padding=0
                        ),
                        padding=0,
                        label='Saved Data',
                        selected=True,
                        enabled_when='not live_mode',
                    ),
                    Group(
                        Item(name='live_mode'),
                        Group(
                            Item(name='host'),
                            Item(name='port'),
                            Item(name='authkey'),
                            enabled_when='live_mode',
                        ),
                        label='Connection',
                    ),
                    layout='tabbed'
                ),
                Group(
                    Group(
                        Item(name='current_time'),
                        Item(name='time_step'),
                        Item(name='iteration'),
                        Item(name='pause_solver', enabled_when='live_mode'),
                        Item(name='interval', enabled_when='not live_mode'),
                        label='Solver',
                    ),
                    Group(
                        Item(name='record'),
                        Item(name='frame_interval'),
                        Item(name='movie_directory'),
                        label='Movie',
                    ),
                    layout='tabbed',
                ),
                Group(
                    Item(name='particle_arrays', style='custom',
                         show_label=False,
                         editor=ListEditor(use_notebook=True,
                                           deletable=False,
                                           page_name='.name')),
                    Item(name='interpolator', style='custom',
                         show_label=False),
                    layout='tabbed'
                ),
                Item(name='shell', show_label=False),
            ),
            Group(
                Item('scene', editor=SceneEditor(scene_class=MayaviScene),
                     height=400, width=600, show_label=False),
            )
        ),
        resizable=True,
        title='PySPH Particle Viewer',
        height=640,
        width=1024,
        handler=ViewerHandler
    )

    ######################################################################
    # `MayaviViewer` interface.
    ######################################################################
    def on_close(self):
        # Persist any pending particle array changes for the current file.
        self._handle_particle_array_updates()

    @on_trait_change('scene:activated')
    def start_timer(self):
        """Start the periodic update timer (live mode only)."""
        if not self.live_mode:
            # No need for the timer if we are rendering files.
            return

        # Just accessing the timer will start it.
        t = self.timer
        if not t.IsRunning():
            t.Start(int(self.interval * 1000))

    @on_trait_change('scene:activated')
    def update_plot(self):
        """Pull time, step and particle arrays from the live solver."""
        # No need to do this if files are being used.
        if not self.live_mode:
            return

        # do not update if solver is paused
        if self.pause_solver:
            return

        if self.client is None:
            self.host_changed = True

        controller = self.controller
        if controller is None:
            return

        self.current_time = t = controller.get_t()
        self.time_step = controller.get_dt()
        self.iteration = controller.get_count()

        arrays = []
        for idx, name in enumerate(self.pa_names):
            pa = controller.get_named_particle_array(name)
            arrays.append(pa)
            pah = self.particle_arrays[idx]
            pah.set(particle_array=pa, time=t)

        self.interpolator.particle_arrays = arrays

        if self.record:
            self._do_snap()

    def run_script(self, path):
        """Execute a script in the namespace of the viewer.
        """
        with open(path) as fp:
            data = fp.read()
            ns = self._get_shell_namespace()
            exec(compile(data, path, 'exec'), ns)

    ######################################################################
    # Private interface.
    ######################################################################
    def _do_snap(self):
        """Generate the animation."""
        p_arrays = self.particle_arrays
        if len(p_arrays) == 0:
            return
        if self.current_time == self._last_time:
            return

        if len(self.movie_directory) == 0:
            controller = self.controller
            if controller is None:
                # No movie directory configured and no solver to ask for
                # its output directory.
                return
            output_dir = controller.get_output_directory()
            movie_dir = os.path.join(output_dir, 'movie')
            self.movie_directory = movie_dir
        else:
            movie_dir = self.movie_directory
        if not os.path.exists(movie_dir):
            os.mkdir(movie_dir)

        interval = self.frame_interval
        count = self._count
        if count % interval == 0:
            fname = 'frame%06d.png' % (self._frame_count)
            p_arrays[0].scene.save_png(os.path.join(movie_dir, fname))
            self._frame_count += 1
            self._last_time = self.current_time
        self._count += 1

    @on_trait_change('host,port,authkey')
    def _mark_reconnect(self):
        if self.live_mode:
            self.host_changed = True

    @cached_property
    def _get_controller(self):
        ''' get the controller, also sets the iteration count '''
        if not self.live_mode:
            return None

        reconnect = self.host_changed
        if not reconnect:
            try:
                # Accessing the controller raises if the link is gone.
                self.client.controller
            except Exception as e:
                logger.info('Error: no connection or connection closed: '
                            'reconnecting: %s' % e)
                reconnect = True
                self.client = None
            else:
                try:
                    self.client.controller.get_count()
                except IOError:
                    self.client = None
                    reconnect = True

        if reconnect:
            self.host_changed = False
            try:
                if MultiprocessingClient.is_available((self.host, self.port)):
                    self.client = MultiprocessingClient(
                        address=(self.host, self.port),
                        authkey=self.authkey
                    )
                else:
                    logger.info('Could not connect: Multiprocessing Interface'
                                ' not available on %s:%s' % (self.host, self.port))
                    return None
            except Exception as e:
                logger.info('Could not connect: check if solver is '
                            'running:%s' % e)
                return None
            c = self.client.controller
            self.iteration = c.get_count()

        if self.client is None:
            return None
        else:
            return self.client.controller

    def _client_changed(self, old, new):
        """Rebuild the particle array helpers for a new live client."""
        if not self.live_mode:
            return

        self._clear()
        if new is None:
            return
        else:
            self.pa_names = self.client.controller.get_particle_array_names()

        self.particle_arrays = [
            self._make_particle_array_helper(self.scene, x)
            for x in self.pa_names
        ]
        self.interpolator = InterpolatorView(scene=self.scene)
        # Turn on the legend for the first particle array.
        if len(self.particle_arrays) > 0:
            self.particle_arrays[0].set(show_legend=True, show_time=True)

    def _timer_event(self):
        # catch all Exceptions else timer will stop
        try:
            self.update_plot()
        except Exception as e:
            logger.info('Exception: %s caught in timer_event' % e)

    def _interval_changed(self, value):
        t = self.timer
        if t is None:
            return
        if t.IsRunning():
            t.Stop()
            t.Start(int(value * 1000))

    def _timer_default(self):
        return Timer(int(self.interval * 1000), self._timer_event)

    def _pause_solver_changed(self, value):
        if self.live_mode:
            c = self.controller
            if c is None:
                return
            if value:
                c.pause_on_next()
            else:
                c.cont()

    def _record_changed(self, value):
        if value:
            self._do_snap()

    def _files_changed(self, value):
        """Reset counters and load the first file of a new file list."""
        if len(value) == 0:
            return
        else:
            d = os.path.dirname(os.path.abspath(value[0]))
            self.movie_directory = os.path.join(d, 'movie')
            self.set(directory=d, trait_change_notify=False)
        self._n_files = len(value) - 1
        self._frame_count = 0
        self._count = 0
        self.frame_interval = 1
        fc = self.file_count
        self.file_count = 0
        if fc == 0:
            # Force an update when our original file count is 0.
            self._file_count_changed(fc)
        t = self.timer
        if not self.live_mode:
            if t.IsRunning():
                t.Stop()
        else:
            if not t.IsRunning():
                t.Stop()
                # Timer.Start takes integer milliseconds, as elsewhere here.
                t.Start(int(self.interval * 1000))

    def _file_count_changed(self, value):
        """Load ``files[value]`` and push its arrays to the helpers."""
        # Save out any updates for the previous file if needed.
        self._handle_particle_array_updates()
        # Load the new file.
        fname = self.files[value]
        self._file_name = fname
        self.current_file = os.path.basename(fname)
        # Code to read the file, create particle array and setup the helper.
        data = load(fname)
        solver_data = data["solver_data"]
        arrays = data["arrays"]
        self._solver_data = solver_data
        self.current_time = t = float(solver_data['t'])
        self.time_step = float(solver_data['dt'])
        self.iteration = int(solver_data['count'])
        names = list(arrays.keys())
        pa_names = self.pa_names

        if len(pa_names) == 0:
            self.interpolator = InterpolatorView(scene=self.scene)
            self.pa_names = names
            pas = []
            for name in names:
                pa = arrays[name]
                pah = self._make_particle_array_helper(self.scene, name)
                # Must set this after setting the scene.
                pah.set(particle_array=pa, time=t)
                pas.append(pah)
            self.particle_arrays = pas
        else:
            for idx, name in enumerate(pa_names):
                pa = arrays[name]
                pah = self.particle_arrays[idx]
                pah.set(particle_array=pa, time=t)

        self.interpolator.particle_arrays = list(arrays.values())

        if self.record:
            self._do_snap()

    def _loop_changed(self, value):
        if value and self.play:
            self._play_changed(self.play)

    def _play_changed(self, value):
        """Switch the timer between file playback and live updates."""
        t = self.timer
        if value:
            t.Stop()
            t.callable = self._play_event
            # Timer.Start takes integer milliseconds.
            t.Start(int(1000 * self.play_delay))
        else:
            t.Stop()
            t.callable = self._timer_event

    def _clear(self):
        self.pa_names = []
        self.scene.mayavi_scene.children[:] = []

    def _play_event(self):
        """Advance to the next file, wrapping or stopping at the end."""
        nf = self._n_files
        pc = self.file_count
        pc += 1
        if pc > nf:
            if self.loop:
                pc = 0
            else:
                self.timer.Stop()
                pc = nf
        self.file_count = pc
        self._handle_particle_array_updates()

    def _play_delay_changed(self):
        if self.play:
            self._play_changed(self.play)

    def _scalar_changed(self, value):
        for pa in self.particle_arrays:
            pa.scalar = value

    def _update_files_fired(self):
        """Re-glob the output files and stay on the current file index."""
        if not self.files:
            return
        fc = self.file_count
        files = glob_files(self.files[fc])
        sort_file_list(files)
        self.files = files
        # The new list may be shorter; stay within the Range bounds.
        self.file_count = min(fc, self._n_files)
        if self.play:
            self._play_changed(self.play)

    def _shell_fired(self):
        ns = self._get_shell_namespace()
        obj = PythonShellView(ns=ns)
        obj.edit_traits()

    def _get_shell_namespace(self):
        return dict(viewer=self, particle_arrays=self.particle_arrays,
                    interpolator=self.interpolator, scene=self.scene,
                    mlab=self.scene.mlab)

    def _directory_changed(self, d):
        """Load files from ``d`` with the extension of the current files."""
        if not self.files:
            # The extension is taken from the loaded files; with none loaded
            # there is nothing to match (this used to raise IndexError).
            return
        ext = os.path.splitext(self.files[-1])[1]
        files = glob.glob(os.path.join(d, '*' + ext))
        if len(files) > 0:
            self._clear()
            sort_file_list(files)
            self.files = files
            # file_count ranges over 0 .. len(files) - 1.
            self.file_count = min(self.file_count, len(files) - 1)

    def _live_mode_changed(self, value):
        if value:
            self._file_name = ''
            self.client = None
            self._clear()
            self._mark_reconnect()
            self.start_timer()
        else:
            self.client = None
            self._clear()
            self.timer.Stop()

    def _particle_array_helper_updated(self, value):
        self._particle_array_updated = True

    def _handle_particle_array_updates(self):
        # Called when the particle array helper fires an updated event.
        if self._particle_array_updated and self._file_name:
            sd = self._solver_data
            arrays = [x.particle_array for x in self.particle_arrays]
            dump(self._file_name, arrays, sd)
            self._particle_array_updated = False

    def _make_particle_array_helper(self, scene, name):
        pah = ParticleArrayHelper(scene=scene, name=name, scalar=self.scalar)
        pah.on_trait_change(self._particle_array_helper_updated, 'updated')
        return pah
class Demo(HasTraits):
    """Demonstrate CSVListEditor on lists with several element trait types."""

    list1 = List(Int)
    list2 = List(Float)
    list3 = List(Str, maxlen=3)
    list4 = List(Enum('red', 'green', 'blue', 2, 3))
    list5 = List(Range(low=0.0, high=10.0))

    # The two Float traits below supply dynamic limits to list6 and list7.
    low = Float(0.0)
    high = Float(1.0)

    list6 = List(Range(low=-1.0, high='high'))
    list7 = List(Range(low='low', high='high'))

    pop1 = Button("Pop from first list")
    sort1 = Button("Sort first list")

    # String form of list1, i.e. str(self.list1).
    list1str = Property(Str, depends_on='list1')

    traits_view = View(
        HGroup(
            # Left column: the editable CSVListEditor examples.
            VGroup(
                Item('list1',
                     label="List(Int)",
                     editor=CSVListEditor(ignore_trailing_sep=False),
                     tooltip='options: ignore_trailing_sep=False'),
                Item('list1',
                     label="List(Int)",
                     style='readonly',
                     editor=CSVListEditor()),
                Item('list2',
                     label="List(Float)",
                     editor=CSVListEditor(enter_set=True, auto_set=False),
                     tooltip='options: enter_set=True, auto_set=False'),
                Item('list3',
                     label="List(Str, maxlen=3)",
                     editor=CSVListEditor()),
                Item('list4',
                     label="List(Enum('red', 'green', 'blue', 2, 3))",
                     editor=CSVListEditor(sep=None),
                     tooltip='options: sep=None'),
                Item('list5',
                     label="List(Range(low=0.0, high=10.0))",
                     editor=CSVListEditor()),
                Item('list6',
                     label="List(Range(low=-1.0, high='high'))",
                     editor=CSVListEditor()),
                Item('list7',
                     label="List(Range(low='low', high='high'))",
                     editor=CSVListEditor()),
                springy=True,
            ),
            # Right column: disabled text views of each list's Python str().
            VGroup(*[
                UItem(trait_name, editor=TextEditor(), enabled_when='False',
                      width=240)
                for trait_name in ('list1str', 'list1str', 'list2', 'list3',
                                   'list4', 'list5', 'list6', 'list7')
            ]),
        ),
        '_',
        HGroup('low', 'high', spring, UItem('pop1'), UItem('sort1')),
        Heading("Notes"),
        Label("Hover over a list to see which editor options are set, "
              "if any."),
        Label("The editor of the first list, List(Int), uses "
              "ignore_trailing_sep=False, so a trailing comma is "
              "an error."),
        Label("The second list is a read-only view of the first list."),
        Label("The editor of the List(Float) example has enter_set=True "
              "and auto_set=False; press Enter to validate."),
        Label("The List(Str) example will accept at most 3 elements."),
        Label("The editor of the List(Enum(...)) example uses sep=None, "
              "i.e. whitespace acts as a separator."),
        Label("The last two List(Range(...)) examples take one or both "
              "of their limits from the Low and High fields below."),
        width=720,
        title="CSVListEditor Demonstration",
    )

    def _list1_default(self):
        return [1, 4, 0, 10]

    def _get_list1str(self):
        return str(self.list1)

    def _pop1_fired(self):
        # Remove and print the last element; do nothing on an empty list.
        if self.list1:
            print(self.list1.pop())

    def _sort1_fired(self):
        self.list1.sort()
class Kit2FiffPanel(HasPrivateTraits):
    """Control panel for kit2fiff conversion.

    Presents the source-file and event settings delegated from ``model``,
    mirrors the model's digitizer points into three ``PointObject`` scene
    objects, and saves FIFF files on a background daemon thread that
    consumes ``(raw, fname)`` jobs from ``queue``.
    """

    model = Instance(Kit2FiffModel)

    # model copies for view
    use_mrk = DelegatesTo('model')
    sqd_file = DelegatesTo('model')
    hsp_file = DelegatesTo('model')
    fid_file = DelegatesTo('model')
    stim_chs = DelegatesTo('model')
    stim_chs_manual = DelegatesTo('model')
    stim_slope = DelegatesTo('model')

    # info
    can_save = DelegatesTo('model')
    sqd_fname = DelegatesTo('model')
    hsp_fname = DelegatesTo('model')
    fid_fname = DelegatesTo('model')

    # Source Files
    reset_dig = Button

    # Visualization
    scene = Instance(MlabSceneModel)
    fid_obj = Instance(PointObject)
    elp_obj = Instance(PointObject)
    hsp_obj = Instance(PointObject)

    # Output
    save_as = Button(label='Save FIFF...')
    clear_all = Button(label='Clear All')
    # The right-hand ``queue`` is the stdlib module: the class attribute is
    # not bound yet while this line is evaluated.
    queue = Instance(queue.Queue, ())
    queue_feedback = Str('')
    queue_current = Str('')
    queue_len = Int(0)
    queue_len_str = Property(Str, depends_on=['queue_len'])

    # Text of the last exception raised while saving (set by the worker).
    error = Str('')

    view = View(
        VGroup(
            VGroup(Item('sqd_file', label="Data"),
                   Item('sqd_fname', show_label=False, style='readonly'),
                   Item('hsp_file', label='Dig Head Shape'),
                   Item('hsp_fname', show_label=False, style='readonly'),
                   Item('fid_file', label='Dig Points'),
                   Item('fid_fname', show_label=False, style='readonly'),
                   Item('reset_dig', label='Clear Digitizer Files',
                        show_label=False),
                   Item('use_mrk', editor=use_editor, style='custom'),
                   label="Sources", show_border=True),
            VGroup(Item('stim_slope', label="Event Onset", style='custom',
                        editor=EnumEditor(
                            values={'+': '2:Peak (0 to 5 V)',
                                    '-': '1:Trough (5 to 0 V)'},
                            cols=2),
                        help="Whether events are marked by a decrease "
                             "(trough) or an increase (peak) in trigger "
                             "channel values"),
                   Item('stim_chs', label="Binary Coding", style='custom',
                        editor=EnumEditor(values={'>': '1:1 ... 128',
                                                  '<': '3:128 ... 1',
                                                  'man': '2:Manual'},
                                          cols=2),
                        help="Specifies the bit order in event "
                             "channels. Assign the first bit (1) to the "
                             "first or the last trigger channel."),
                   Item('stim_chs_manual', label='Stim Channels',
                        style='custom',
                        visible_when="stim_chs == 'man'"),
                   label='Events', show_border=True),
            HGroup(Item('save_as', enabled_when='can_save'), spring,
                   'clear_all', show_labels=False),
            Item('queue_feedback', show_label=False, style='readonly'),
            Item('queue_current', show_label=False, style='readonly'),
            Item('queue_len_str', show_label=False, style='readonly'),
        )
    )

    def __init__(self, *args, **kwargs):
        super(Kit2FiffPanel, self).__init__(*args, **kwargs)

        # setup save worker
        def worker():
            # Runs forever on a daemon thread: take one save job at a time,
            # report progress through the queue_* traits, never die on a
            # failed save.
            while True:
                raw, fname = self.queue.get()
                basename = os.path.basename(fname)
                self.queue_len -= 1
                self.queue_current = 'Processing: %s' % basename

                # task
                try:
                    raw.save(fname, overwrite=True)
                except Exception as err:
                    self.error = str(err)
                    res = "Error saving: %s"
                else:
                    res = "Saved: %s"

                # finalize
                self.queue_current = ''
                self.queue_feedback = res % basename
                self.queue.task_done()

        t = Thread(target=worker)
        t.daemon = True
        t.start()

        # setup mayavi visualization
        m = self.model
        self.fid_obj = PointObject(scene=self.scene, color=(25, 225, 25),
                                   point_scale=5e-3)
        m.sync_trait('fid', self.fid_obj, 'points', mutual=False)
        m.sync_trait('head_dev_trans', self.fid_obj, 'trans', mutual=False)

        self.elp_obj = PointObject(scene=self.scene, color=(50, 50, 220),
                                   point_scale=1e-2, opacity=.2)
        m.sync_trait('elp', self.elp_obj, 'points', mutual=False)
        m.sync_trait('head_dev_trans', self.elp_obj, 'trans', mutual=False)

        self.hsp_obj = PointObject(scene=self.scene, color=(200, 200, 200),
                                   point_scale=2e-3)
        m.sync_trait('hsp', self.hsp_obj, 'points', mutual=False)
        m.sync_trait('head_dev_trans', self.hsp_obj, 'trans', mutual=False)

        self.scene.camera.parallel_scale = 0.15
        self.scene.mlab.view(0, 0, .15)

    def _clear_all_fired(self):
        self.model.clear_all()

    @cached_property
    def _get_queue_len_str(self):
        """Return a 'Queue length' label, or '' when nothing is queued."""
        if self.queue_len:
            return "Queue length: %i" % self.queue_len
        else:
            return ''

    def _reset_dig_fired(self):
        self.reset_traits(['hsp_file', 'fid_file'])

    def _save_as_fired(self):
        """Ask for a target path and queue the model's raw data for saving.

        Errors from building the raw object are shown in a dialog and then
        re-raised.
        """
        # create raw
        try:
            raw = self.model.get_raw()
        except Exception as err:
            error(None, str(err), "Error Creating KIT Raw")
            raise

        # find default path
        stem, _ = os.path.splitext(self.sqd_file)
        if not stem.endswith('raw'):
            stem += '-raw'
        default_path = stem + '.fif'

        # save as dialog
        dlg = FileDialog(action="save as",
                         wildcard="fiff raw file (*.fif)|*.fif",
                         default_path=default_path)
        dlg.open()
        if dlg.return_code != OK:
            return

        fname = dlg.path
        if not fname.endswith('.fif'):
            fname += '.fif'
            # The save-as dialog only confirmed overwriting the path the
            # user typed; the path with '.fif' appended needs its own check.
            if os.path.exists(fname):
                # Interpolate the file name: the message previously showed
                # a literal "%r" because the format was never applied.
                answer = confirm(None, "The file %r already exists. Should it "
                                 "be replaced?" % fname, "Overwrite File?")
                if answer != YES:
                    return

        self.queue.put((raw, fname))
        self.queue_len += 1
class TableFilterEditor(HasTraits):
    """ An editor that manages table filters. """

    # -------------------------------------------------------------------------
    # Trait definitions:
    # -------------------------------------------------------------------------

    #: TableEditor this editor is associated with
    editor = Instance(TableEditor)

    #: The list of filters
    filters = List(TableFilter)

    #: The list of available templates from which filters can be created.
    #: Observes "filters.items" rather than just "filters" so the cached
    #: value is also invalidated by in-place mutation (``filters.append``
    #: in _add_button_fired, ``del filters[i]`` in _remove_button_fired),
    #: not only by reassignment of the whole list.
    templates = Property(List(TableFilter), observe="filters.items")

    #: The currently selected filter template
    selected_template = Instance(TableFilter)

    #: The currently selected filter
    selected_filter = Instance(TableFilter, allow_none=True)

    #: The view to use for the current filter
    selected_filter_view = Property(observe="selected_filter")

    #: Buttons for add/removing filters
    add_button = Button("New")
    remove_button = Button("Delete")

    # The default view for this editor
    view = View(
        Group(
            Group(
                Group(
                    Item("add_button", enabled_when="selected_template"),
                    Item(
                        "remove_button",
                        enabled_when="len(templates) > 1 and "
                        "selected_filter is not None",
                    ),
                    orientation="horizontal",
                    show_labels=False,
                ),
                Label("Base filter for new filters:"),
                Item("selected_template", editor=EnumEditor(name="templates")),
                Item(
                    "selected_filter",
                    style="custom",
                    editor=EnumEditor(name="filters", mode="list"),
                ),
                show_labels=False,
            ),
            Item(
                "selected_filter",
                width=0.75,
                style="custom",
                editor=InstanceEditor(view_name="selected_filter_view"),
            ),
            id="TableFilterEditorSplit",
            show_labels=False,
            layout="split",
            orientation="horizontal",
        ),
        id="traitsui.qt4.table_editor.TableFilterEditor",
        buttons=["OK", "Cancel"],
        kind="livemodal",
        resizable=True,
        width=800,
        height=400,
        title="Customize filters",
    )

    # -------------------------------------------------------------------------
    # Private methods:
    # -------------------------------------------------------------------------

    # -- Trait Property getter/setters ----------------------------------------

    @cached_property
    def _get_selected_filter_view(self):
        """ Return the edit view of the selected filter (or None).

        The view is built against the object shown in the first row of the
        table (after mapping through the proxy model), or None when the
        table has no valid first row.
        """
        view = None
        if self.selected_filter:
            model = self.editor.model
            index = model.mapToSource(model.index(0, 0))
            if index.isValid():
                obj = self.editor.items()[index.row()]
            else:
                obj = None
            view = self.selected_filter.edit_view(obj)
        return view

    @cached_property
    def _get_templates(self):
        """ Return the factory's template filters followed by user filters. """
        templates = [f for f in self.editor.factory.filters if f.template]
        templates.extend(self.filters)
        return templates

    # -- Trait Change Handlers ------------------------------------------------

    def _editor_changed(self):
        """ Copy the factory's non-template filters and pick a template. """
        self.filters = [
            f.clone_traits()
            for f in self.editor.factory.filters
            if not f.template
        ]
        templates = self.templates
        # Guard against a factory that defines no filters at all.
        self.selected_template = templates[0] if templates else None

    def _add_button_fired(self):
        """ Create a new filter based on the selected template and select it.
        """
        new_filter = self.selected_template.clone_traits()
        new_filter.template = False
        new_filter.name = new_filter._name = "New filter"
        self.filters.append(new_filter)
        self.selected_filter = new_filter

    def _remove_button_fired(self):
        """ Delete the currently selected filter.

        The filter at the same position (or None, if the deleted filter was
        last) becomes the new selection.
        """
        removed = self.selected_filter
        index = self.filters.index(removed)
        del self.filters[index]

        # Re-pick the template only after the deletion, so the fallback can
        # never be the filter that was just removed.
        if self.selected_template == removed:
            templates = self.templates
            self.selected_template = templates[0] if templates else None

        if index < len(self.filters):
            self.selected_filter = self.filters[index]
        else:
            self.selected_filter = None

    @observe("selected_filter:name")
    def _update_filter_list(self, event):
        """ A hack to make the EnumEditor watching the list of filters refresh
            their text when the name of the selected filter changes.
        """
        filters = self.filters
        self.filters = []
        self.filters = filters
class BaselineView(HasTraits):
    """Baseline (NED) solution panel: a status table plus an N/E scatter plot.

    Registers callbacks on ``link`` for baseline, baseline-heading,
    IAR-state, GPS-time, UTC-time and age-of-corrections messages. It keeps
    a rolling history of ``plot_history_max`` positions tagged with their
    solution mode. When ``logging_b`` is set, every baseline solution is
    also appended to a CSV log.
    """

    python_console_cmds = Dict()

    table = List()

    logging_b = Bool(False)
    directory_name_b = File

    plot = Instance(Plot)
    plot_data = Instance(ArrayPlotData)

    running = Bool(True)
    zoomall = Bool(False)
    position_centered = Bool(False)

    clear_button = SVGButton(
        label='',
        tooltip='Clear',
        filename=resource_filename('console/images/iconic/x.svg'),
        width=16,
        height=16)
    zoomall_button = SVGButton(
        label='',
        tooltip='Zoom All',
        toggle=True,
        filename=resource_filename('console/images/iconic/fullscreen.svg'),
        width=16,
        height=16)
    center_button = SVGButton(
        label='',
        tooltip='Center on Baseline',
        toggle=True,
        filename=resource_filename('console/images/iconic/target.svg'),
        width=16,
        height=16)
    paused_button = SVGButton(
        label='',
        tooltip='Pause',
        toggle_tooltip='Run',
        toggle=True,
        filename=resource_filename('console/images/iconic/pause.svg'),
        toggle_filename=resource_filename('console/images/iconic/play.svg'),
        width=16,
        height=16)

    reset_button = Button(label='Reset Filters')

    traits_view = View(
        HSplit(
            Item(
                'table',
                style='readonly',
                editor=TabularEditor(adapter=SimpleAdapter()),
                show_label=False,
                width=0.3),
            VGroup(
                HGroup(
                    Item('paused_button', show_label=False),
                    Item('clear_button', show_label=False),
                    Item('zoomall_button', show_label=False),
                    Item('center_button', show_label=False),
                    Item('reset_button', show_label=False),
                ),
                Item(
                    'plot',
                    show_label=False,
                    editor=ComponentEditor(bgcolor=(0.8, 0.8, 0.8)),
                ))))

    def _zoomall_button_fired(self):
        self.zoomall = not self.zoomall

    def _center_button_fired(self):
        self.position_centered = not self.position_centered

    def _paused_button_fired(self):
        self.running = not self.running

    def _reset_button_fired(self):
        self.link(MsgResetFilters(filter=0))

    def _reset_remove_current(self):
        """Empty the 'current solution' marker series (cur_<mode>_<axis>)."""
        for mode in ('fixed', 'float', 'dgnss'):
            for axis in ('n', 'e', 'd'):
                self.plot_data.set_data('cur_%s_%s' % (mode, axis), [])

    def _clear_history(self):
        """Empty the per-mode position history series (<axis>_<mode>)."""
        for mode in ('fixed', 'float', 'dgnss'):
            for axis in ('n', 'e', 'd'):
                self.plot_data.set_data('%s_%s' % (axis, mode), [])

    def _clear_button_fired(self):
        """Forget the whole position history and blank the plot."""
        # np.nan rather than np.NAN: the upper-case alias was removed in
        # NumPy 2.0.
        self.n[:] = np.nan
        self.e[:] = np.nan
        self.d[:] = np.nan
        self.mode[:] = np.nan
        self.plot_data.set_data('t', [])
        self._clear_history()
        self._reset_remove_current()

    def iar_state_callback(self, sbp_msg, **metadata):
        self.num_hyps = sbp_msg.num_hyps
        self.last_hyp_update = time.time()

    def age_corrections_callback(self, sbp_msg, **metadata):
        age_msg = MsgAgeCorrections(sbp_msg)
        # 0xFFFF is treated as "no age available"; otherwise the raw value
        # is divided by 10 for display in seconds.
        if age_msg.age != 0xFFFF:
            self.age_corrections = age_msg.age / 10.0
        else:
            self.age_corrections = None

    def gps_time_callback(self, sbp_msg, **metadata):
        if sbp_msg.msg_type == SBP_MSG_GPS_TIME_DEP_A:
            time_msg = MsgGPSTimeDepA(sbp_msg)
            flags = 1
        elif sbp_msg.msg_type == SBP_MSG_GPS_TIME:
            time_msg = MsgGPSTime(sbp_msg)
            flags = time_msg.flags
        else:
            # Not a GPS-time message; previously this fell through to an
            # unbound ``flags``.
            return
        if flags != 0:
            self.week = time_msg.wn
            self.nsec = time_msg.ns_residual

    def utc_time_callback(self, sbp_msg, **metadata):
        tmsg = MsgUtcTime(sbp_msg)
        # datetime() needs an integral seconds field; the floored value was
        # previously computed but never used.
        seconds = int(math.floor(tmsg.seconds))
        microseconds = int(tmsg.ns / 1000.00)
        if tmsg.flags & 0x1 == 1:
            dt = datetime.datetime(tmsg.year, tmsg.month, tmsg.day, tmsg.hours,
                                   tmsg.minutes, seconds, microseconds)
            self.utc_time = dt
            self.utc_time_flags = tmsg.flags
            # Bits 3-4 of the flags select the time source.
            utc_sources = {
                0: "Factory Default",
                1: "Non Volatile Memory",
                2: "Decoded this Session",
            }
            self.utc_source = utc_sources.get((tmsg.flags >> 3) & 0x3,
                                              "Unknown")
        else:
            self.utc_time = None
            self.utc_source = None

    def baseline_heading_callback(self, sbp_msg, **metadata):
        headingMsg = MsgBaselineHeading(sbp_msg)
        if headingMsg.flags & 0x7 != 0:
            self.heading = headingMsg.heading * 1e-3
        else:
            self.heading = "---"

    def baseline_callback(self, sbp_msg, **metadata):
        """Handle one baseline message: log it, refresh the table and history."""
        soln = MsgBaselineNEDDepA(sbp_msg)
        self.last_soln = soln

        # Scale the raw integer fields by 1e-3 (labelled as meters below).
        soln.n = soln.n * 1e-3
        soln.e = soln.e * 1e-3
        soln.d = soln.d * 1e-3
        soln.h_accuracy = soln.h_accuracy * 1e-3
        soln.v_accuracy = soln.v_accuracy * 1e-3

        dist = np.sqrt(soln.n**2 + soln.e**2 + soln.d**2)

        tow = soln.tow * 1e-3
        if self.nsec is not None:
            tow += self.nsec * 1e-9

        ((tloc, secloc), (tgps, secgps)) = log_time_strings(self.week, tow)

        # Read utc_time once so both uses below agree even if the UTC
        # callback updates it meanwhile.
        utc_time = self.utc_time
        if utc_time is not None:
            tutc, secutc = datetime_2_str(utc_time)

        if not self.logging_b:
            # Close the file when logging is switched off; previously the
            # handle was just dropped and leaked.
            if self.log_file is not None:
                self.log_file.close()
            self.log_file = None
        else:
            if self.log_file is None:
                log_name = time.strftime("baseline_log_%Y%m%d-%H%M%S.csv")
                if self.directory_name_b == '':
                    filepath = log_name
                else:
                    filepath = os.path.join(self.directory_name_b, log_name)
                self.log_file = sopen(filepath, 'w')
                self.log_file.write(
                    'pc_time,gps_time,tow(sec),north(meters),east(meters),down(meters),h_accuracy(meters),v_accuracy(meters),'
                    'distance(meters),num_sats,flags,num_hypothesis\n')
            log_str_gps = ''
            if tgps != '' and secgps != 0:
                log_str_gps = "{0}:{1:06.6f}".format(tgps, float(secgps))
            self.log_file.write(
                '%s,%s,%.3f,%.4f,%.4f,%.4f,%.4f,%.4f,%.4f,%d,%d,%d\n' %
                ("{0}:{1:06.6f}".format(tloc, float(secloc)), log_str_gps,
                 tow, soln.n, soln.e, soln.d, soln.h_accuracy,
                 soln.v_accuracy, dist, soln.n_sats, soln.flags,
                 self.num_hyps))
            self.log_file.flush()

        self.last_mode = get_mode(soln)

        if self.last_mode < 1:
            table = [(label, EMPTY_STR) for label in (
                'GPS Week', 'GPS TOW', 'GPS Time', 'UTC Time', 'UTC Src',
                'N', 'E', 'D', 'Horiz Acc', 'Vert Acc', 'Dist.',
                'Sats Used', 'Flags', 'Mode')]
        else:
            self.last_btime_update = time.time()
            table = []
            if self.week is not None:
                table.append(('GPS Week', str(self.week)))
            table.append(('GPS TOW', "{:.3f}".format(tow)))

            if self.week is not None:
                table.append(('GPS Time', "{0}:{1:06.3f}".format(
                    tgps, float(secgps))))
            if utc_time is not None:
                table.append(('UTC Time', "{0}:{1:06.3f}".format(
                    tutc, float(secutc))))
                table.append(('UTC Src', self.utc_source))

            table.extend([
                ('N', soln.n),
                ('E', soln.e),
                ('D', soln.d),
                ('Horiz Acc', soln.h_accuracy),
                ('Vert Acc', soln.v_accuracy),
                ('Dist.', "{0:.3f}".format(dist)),
                ('Sats Used', soln.n_sats),
                ('Flags', '0x%02x' % soln.flags),
                ('Mode', mode_dict[self.last_mode]),
            ])
        if self.heading is not None:
            table.append(('Heading', self.heading))
        if self.age_corrections is not None:
            table.append(('Corr. Age [s]', self.age_corrections))
        self.table = table

        # Rotate array, deleting oldest entries to maintain
        # no more than N in plot
        self.n[1:] = self.n[:-1]
        self.e[1:] = self.e[:-1]
        self.d[1:] = self.d[:-1]
        self.mode[1:] = self.mode[:-1]

        # Insert latest position (NaN when the mode is not plottable)
        if self.last_mode > 1:
            self.n[0], self.e[0], self.d[0] = soln.n, soln.e, soln.d
        else:
            self.n[0], self.e[0], self.d[0] = np.nan, np.nan, np.nan
        self.mode[0] = self.last_mode

    def solution_draw(self):
        """Schedule a plot refresh on the GUI thread unless paused."""
        if self.running:
            GUI.invoke_later(self._solution_draw)

    def _solution_draw(self):
        """Redraw the history, the current-position marker and the view range."""
        self._clear_history()
        soln = self.last_soln
        if np.any(self.mode):
            for mode_id, tag in ((FIXED_MODE, 'fixed'),
                                 (FLOAT_MODE, 'float'),
                                 (DGNSS_MODE, 'dgnss')):
                indexer = (self.mode == mode_id)
                if np.any(indexer):
                    self.plot_data.set_data('n_' + tag, self.n[indexer])
                    self.plot_data.set_data('e_' + tag, self.e[indexer])
                    self.plot_data.set_data('d_' + tag, self.d[indexer])

        # Update our last solution icon
        current = {FIXED_MODE: 'fixed',
                   FLOAT_MODE: 'float',
                   DGNSS_MODE: 'dgnss'}.get(self.last_mode)
        if current is not None:
            self._reset_remove_current()
            self.plot_data.set_data('cur_%s_n' % current, [soln.n])
            self.plot_data.set_data('cur_%s_e' % current, [soln.e])
            self.plot_data.set_data('cur_%s_d' % current, [soln.d])

        # make the zoomall win over the position centered button
        # position centered button has no effect when zoom all enabled.
        # Skip centering until a first solution exists (last_soln is None
        # before that and would raise AttributeError).
        if soln is not None and not self.zoomall and self.position_centered:
            d = (self.plot.index_range.high - self.plot.index_range.low) / 2.
            self.plot.index_range.set_bounds(soln.e - d, soln.e + d)
            d = (self.plot.value_range.high - self.plot.value_range.low) / 2.
            self.plot.value_range.set_bounds(soln.n - d, soln.n + d)

        if self.zoomall:
            plot_square_axes(self.plot, ('e_fixed', 'e_float', 'e_dgnss'),
                             ('n_fixed', 'n_float', 'n_dgnss'))

    def __init__(self, link, plot_history_max=1000, dirname=''):
        super(BaselineView, self).__init__()
        self.log_file = None
        self.directory_name_b = dirname
        self.num_hyps = 0
        self.last_hyp_update = 0
        self.last_btime_update = 0
        self.last_soln = None
        self.last_mode = 0
        self.plot_data = ArrayPlotData(
            n_fixed=[0.0],
            e_fixed=[0.0],
            d_fixed=[0.0],
            n_float=[0.0],
            e_float=[0.0],
            d_float=[0.0],
            n_dgnss=[0.0],
            e_dgnss=[0.0],
            d_dgnss=[0.0],
            t=[0.0],
            ref_n=[0.0],
            ref_e=[0.0],
            ref_d=[0.0],
            cur_fixed_e=[],
            cur_fixed_n=[],
            cur_fixed_d=[],
            cur_float_e=[],
            cur_float_n=[],
            cur_float_d=[],
            cur_dgnss_e=[],
            cur_dgnss_n=[],
            cur_dgnss_d=[])

        self.plot_history_max = plot_history_max
        self.n = np.zeros(plot_history_max)
        self.e = np.zeros(plot_history_max)
        self.d = np.zeros(plot_history_max)
        self.mode = np.zeros(plot_history_max)

        self.plot = Plot(self.plot_data)
        # History scatter series, one per solution mode.
        self.plot.plot(
            ('e_float', 'n_float'),
            type='scatter',
            color=color_dict[FLOAT_MODE],
            marker='dot',
            line_width=0.0,
            marker_size=1.0)
        self.plot.plot(
            ('e_fixed', 'n_fixed'),
            type='scatter',
            color=color_dict[FIXED_MODE],
            marker='dot',
            line_width=0.0,
            marker_size=1.0)
        self.plot.plot(
            ('e_dgnss', 'n_dgnss'),
            type='scatter',
            color=color_dict[DGNSS_MODE],
            marker='dot',
            line_width=0.0,
            marker_size=1.0)
        ref = self.plot.plot(
            ('ref_e', 'ref_n'),
            type='scatter',
            color='red',
            marker='plus',
            marker_size=5,
            line_width=1.5)
        cur_fixed = self.plot.plot(
            ('cur_fixed_e', 'cur_fixed_n'),
            type='scatter',
            color=color_dict[FIXED_MODE],
            marker='plus',
            marker_size=5,
            line_width=1.5)
        cur_float = self.plot.plot(
            ('cur_float_e', 'cur_float_n'),
            type='scatter',
            color=color_dict[FLOAT_MODE],
            marker='plus',
            marker_size=5,
            line_width=1.5)
        cur_dgnss = self.plot.plot(
            ('cur_dgnss_e', 'cur_dgnss_n'),
            type='scatter',
            color=color_dict[DGNSS_MODE],
            marker='plus',
            line_width=1.5,
            marker_size=5)
        plot_labels = [' Base Position', 'DGPS', 'RTK Float', 'RTK Fixed']
        plots_legend = dict(
            zip(plot_labels, [ref, cur_dgnss, cur_float, cur_fixed]))
        self.plot.legend.plots = plots_legend
        self.plot.legend.labels = plot_labels  # sets order
        self.plot.legend.visible = True

        for axis, title in ((self.plot.index_axis, 'E (meters)'),
                            (self.plot.value_axis, 'N (meters)')):
            axis.tick_label_position = 'inside'
            axis.tick_label_color = 'gray'
            axis.tick_color = 'gray'
            axis.title = title
            axis.title_spacing = 5
        self.plot.padding = (25, 25, 25, 25)

        self.plot.tools.append(PanTool(self.plot))
        zt = ZoomTool(
            self.plot, zoom_factor=1.1, tool_mode="box", always_on=False)
        self.plot.overlays.append(zt)

        self.week = None
        self.utc_time = None
        self.age_corrections = None
        self.heading = "---"
        self.nsec = 0

        self.link = link
        self.link.add_callback(self.baseline_callback, [
            SBP_MSG_BASELINE_NED, SBP_MSG_BASELINE_NED_DEP_A
        ])
        self.link.add_callback(self.baseline_heading_callback,
                               [SBP_MSG_BASELINE_HEADING])
        self.link.add_callback(self.iar_state_callback, SBP_MSG_IAR_STATE)
        self.link.add_callback(self.gps_time_callback,
                               [SBP_MSG_GPS_TIME, SBP_MSG_GPS_TIME_DEP_A])
        self.link.add_callback(self.utc_time_callback, [SBP_MSG_UTC_TIME])
        self.link.add_callback(self.age_corrections_callback,
                               SBP_MSG_AGE_CORRECTIONS)

        call_repeatedly(0.2, self.solution_draw)

        self.python_console_cmds = {'baseline': self}
class WaitControl(Loggable):
    """Countdown wait step driven by a one-second ``Timer``.

    ``start`` arms the timer, and by default blocks until ``end_evt`` is
    set. The event is set when the countdown reaches zero, when ``stop`` is
    called, or when the user presses Continue.
    """

    page_name = Str('Wait')
    message = Str
    message_color = Color('black')

    high = Float
    duration = Float(10)
    current_time = Float

    auto_start = Bool(False)
    timer = None
    end_evt = None

    continue_button = Button('Continue')

    _continued = Bool
    _canceled = Bool
    _no_update = False

    def __init__(self, *args, **kw):
        self.reset()
        super(WaitControl, self).__init__(*args, **kw)
        if self.auto_start:
            self.start(evt=self.end_evt)

    def is_active(self):
        """Return the timer's ``isActive()``, or None if no timer exists."""
        active_timer = self.timer
        if active_timer:
            return active_timer.isActive()

    def is_canceled(self):
        return self._canceled

    def is_continued(self):
        return self._continued

    def join(self, evt=None):
        """Block until *evt* (default: the current end event) is set."""
        waiter = self.end_evt if evt is None else evt
        time.sleep(0.25)
        while not waiter.is_set():
            waiter.wait(0.005)
        self.debug('Join finished')

    def start(self, block=True, evt=None, duration=None, message=None):
        """Arm a new countdown, optionally blocking until it ends."""
        previous_evt = self.end_evt
        if previous_evt:
            previous_evt.set()

        evt = Event() if evt is None else evt
        if evt:
            evt.clear()
        self.end_evt = evt

        running_timer = self.timer
        if running_timer:
            running_timer.stop()
            running_timer.wait_for_completion()

        if duration:
            self.duration = duration
            self.reset()

        if message:
            self.message = message

        self.timer = Timer(1000, self._update_time, delay=1000)
        self._continued = False

        if not block:
            return
        self.join(evt=evt)
        if evt == self.end_evt:
            self.end_evt = None

    def stop(self):
        """End the countdown early; show 'Stopped' if time was remaining."""
        self._end()
        self.debug('wait dialog stopped')
        if self.current_time > 1:
            self.message = 'Stopped'
            self.message_color = 'red'

    def reset(self):
        """Restore ``high`` and ``current_time`` to ``duration``."""
        with no_update(self, fire_update_needed=False):
            full_span = self.duration
            self.high = full_span
            self.current_time = full_span

    # ===============================================================================
    # private
    # ===============================================================================
    def _continue(self):
        self._continued = True
        self._end()
        self.current_time = 0

    def _end(self):
        """Clear the message, stop the timer and release any waiter."""
        self.message = ''
        if self.timer is not None:
            self.timer.Stop()
        if self.end_evt is not None:
            self.end_evt.set()

    def _update_time(self):
        """Timer tick: count down one second and finish when time runs out."""
        remaining = self.current_time
        ticking = self.timer
        if not (ticking and ticking.isActive()):
            return

        self.current_time -= 1
        remaining -= 1
        if remaining <= 0:
            self._end()
            self._canceled = False
        else:
            self.current_time = remaining

    # ===============================================================================
    # handlers
    # ===============================================================================
    def _continue_button_fired(self):
        self._continue()

    def _high_changed(self, v):
        if not self._no_update:
            # Chained assignment binds left to right: duration, then current_time.
            self.duration = self.current_time = v
class IntervalSpectrogram(PlotsInterval):
    """Multitaper spectrogram of the current interval, drawn under its trace.

    Optionally normalized against a leading baseline period (see
    ``_normalize``).
    """

    name = 'Spectrogram'
    # overload the figure stack to use GridSpec
    _figstack = Instance(GridSpecStack, args=())

    # NOTE(review): ``channel``, ``_chan_list``, ``chan_map``, ``parent`` and
    # ``curve_manager`` are used below but not declared here; presumably
    # they come from PlotsInterval -- confirm.
    ## channel = Str('all')
    ## _chan_list = Property(depends_on='parent.chan_map')

    # estimations details
    NW = Enum(2.5, np.arange(2, 11, 0.5).tolist())
    lag = Float(25)  # ms
    strip = Float(100)  # ms
    detrend = Bool(True)
    adaptive = Bool(True)
    over_samp = Range(0, 16, mode='spinner', value=4)
    high_res = Bool(True)
    _bandwidth = Property(Float, depends_on='NW, strip')

    baseline = Float(1.0)  # sec
    normalize = Enum('Z score', ('None', 'Z score', 'divide', 'subtract'))
    plot = Button('Plot')
    freq_hi = Float
    freq_lo = Float(0)
    log_freq = Bool(False)
    new_figure = Bool(True)

    colormap = 'Spectral_r'

    def __init__(self, **traits):
        # NOTE(review): calls HasTraits.__init__ directly, bypassing any
        # PlotsInterval.__init__ -- kept as in the original.
        HasTraits.__init__(self, **traits)
        # Default upper frequency: half of 1 / parent.x_scale (Nyquist, if
        # x_scale is the sample period).
        self.freq_hi = round((self.parent.x_scale**-1.0) / 2.0)

    @cached_property
    def _get__bandwidth(self):
        """Full bandwidth 2W in Hz for the current NW and strip length."""
        T = self.strip * 1e-3
        # TW = NW --> W = NW / T
        return 2.0 * self.NW / T

    def __default_mtm_kwargs(self, strip):
        """Build the multitaper keyword arguments and the strip length.

        Returns ``(kw, strip)``, where *strip* is in samples (derived from
        ``self.strip`` ms when not given).
        """
        kw = dict()
        kw['NW'] = self.NW
        kw['adaptive_weights'] = self.adaptive
        Fs = self.parent.x_scale**-1.0
        kw['Fs'] = Fs
        kw['jackknife'] = False
        kw['detrend'] = 'linear' if self.detrend else ''
        lag = int(Fs * self.lag / 1000.)
        if strip is None:
            strip = int(Fs * self.strip / 1000.)
        kw['nfft'] = nextpow2(strip)
        kw['pl'] = 1.0 - float(lag) / strip
        if self.high_res:
            kw['samp_factor'] = self.over_samp
            kw['low_bias'] = 0.95
        return kw, strip

    def _normalize(self, tx, ptf, mode=None, tc=None):
        """Normalize a spectrogram against its baseline period.

        Parameters
        ----------
        tx : array of frame times; frames with ``tx < tc`` form the baseline
            (the first two frames are used if none qualify).
        ptf : ``([n_chan], n_freq, n_time)`` power array.
        mode : one of 'divide', 'subtract', 'z score' (case-insensitive);
            defaults to ``self.normalize``.
        tc : baseline cutoff time; defaults to ``self.baseline``.

        Returns
        -------
        ``(n_freq, n_time)`` normalized array (channel-averaged) for the
        three known modes.
        """
        if not mode:
            mode = self.normalize
        if not tc:
            tc = self.baseline
        m = tx < tc
        if not m.any():
            print('Baseline too short: using 2 bins')
            m[:2] = True
        if ptf.ndim < 3:
            ptf = ptf.reshape((1, ) + ptf.shape)
        nf = ptf.shape[1]
        # subtraction and z-score normalizations should be done with log transform
        if mode.lower() in ('subtract', 'z score'):
            ptf = np.log10(ptf)
        # Stack every (channel, baseline frame) pair as a row of nf values.
        p_bl = ptf[..., m].transpose(0, 2, 1).copy().reshape(-1, nf)
        # get mean of baseline for each freq.
        jn = Jackknife(p_bl, axis=0)
        mn = jn.estimate(np.mean, se=False)
        # Smooth the mean across frequency (Gaussian, sigma = NW bins), but
        # only replace values beyond the first 2NW bins.
        bias_pts = int(2 * self.NW)
        mn[bias_pts:] = gaussian_filter1d(mn, self.NW, mode='reflect')[bias_pts:]
        assert len(mn) == nf, '?'
        if mode.lower() == 'divide':
            ptf = ptf.mean(0) / mn[:, None]
        elif mode.lower() == 'subtract':
            ptf = ptf.mean(0) - mn[:, None]
        elif mode.lower() == 'z score':
            stdev = jn.estimate(np.std, se=False)
            stdev[bias_pts:] = gaussian_filter1d(stdev, self.NW, mode='reflect')[bias_pts:]
            ptf = (ptf.mean(0) - mn[:, None]) / stdev[:, None]
        return ptf

    def highres_spectrogram(self, array, strip=None, **mtm_kw):
        """High-resolution MTM spectrogram, trimming 6 frames at each end."""
        kw, strip = self.__default_mtm_kwargs(strip)
        kw.update(**mtm_kw)
        tx, fx, ptf = msp.mtm_spectrogram(array, strip, **kw)
        n_cut = 6
        tx = tx[n_cut:-n_cut]
        ptf = ptf[..., n_cut:-n_cut].copy()
        return tx, fx, ptf

    def spectrogram(self, array, strip=None, **mtm_kw):
        """Basic MTM spectrogram; returns ``(tx, fx, ptf)``."""
        kw, strip = self.__default_mtm_kwargs(strip)
        kw.update(**mtm_kw)
        tx, fx, ptf = msp.mtm_spectrogram_basic(array, strip, **kw)
        return tx, fx, ptf

    def _plot_fired(self):
        """Compute the spectrogram of the current data and draw the figure."""
        x, y = self.curve_manager.interactive_curve.current_data()
        # Scale by 1e6 (axis is labelled uV) into a NEW array. The previous
        # in-place ``y *= 1e6`` rescaled the caller's buffer if
        # current_data() returns a view, and fails for integer arrays.
        y = np.asarray(y) * 1e6
        if self.channel.lower() != 'all':
            i, j = list(map(float, self.channel.split(',')))
            y = y[self.chan_map.lookup(i, j)]

        if self.high_res:
            tx, fx, ptf = self.highres_spectrogram(y)
        else:
            tx, fx, ptf = self.spectrogram(y)

        if self.normalize.lower() != 'none':
            ptf = self._normalize(tx, ptf, self.normalize, self.baseline)
        else:
            ptf = np.log10(ptf)
        # Average over channels if a channel axis remains.
        if ptf.ndim > 2:
            ptf = ptf.mean(0)
        # cut out non-display frequencies and advance timebase to match window
        m = (fx >= self.freq_lo) & (fx <= self.freq_hi)
        ptf = ptf[m]
        fx = fx[m]
        if fx[0] == 0 and self.log_freq:
            # A zero frequency cannot be shown on a log axis.
            fx[0] = 0.5 * (fx[0] + fx[1])
        tx += x[0]

        fig, gs = self._get_fig(figsize=(8, 10), nrows=2, ncols=2,
                                height_ratios=[1, 3], width_ratios=[20, 1])

        ts_ax = fig.add_subplot(gs[0, 0])
        ts_ax.plot(x, y)
        ts_ax.set_ylabel(r'$\mu$V')
        for t in ts_ax.get_xticklabels():
            t.set_visible(False)
        sg_ax = fig.add_subplot(gs[1, 0])
        im = sg_ax.imshow(ptf, extent=[tx[0], tx[-1], fx[0], fx[-1]],
                          cmap=self.colormap, origin='lower')
        if self.log_freq:
            sg_ax.set_yscale('log')
            sg_ax.set_ylim(fx[1], fx[-1])
        sg_ax.axis('auto')
        sg_ax.set_xlabel('Time (s)')
        sg_ax.set_ylabel('Frequency (Hz)')
        cb_ax = fig.add_subplot(gs[:, 1])
        cbar = fig.colorbar(im, cax=cb_ax, aspect=80)
        cbar.locator = ticker.MaxNLocator(nbins=3)
        # Older matplotlib only applies a newly assigned locator after
        # update_ticks(); the call is harmless on newer releases.
        cbar.update_ticks()
        if self.normalize.lower() == 'divide':
            title = 'Normalized ratio'
        elif self.normalize.lower() == 'subtract':
            title = 'Baseline subtracted'
        elif self.normalize.lower() == 'z score':
            title = 'Z score'
        else:
            title = 'Log-power'
        cbar.set_label(title)
        gs.tight_layout(fig, w_pad=0.025)

        ts_ax.set_xlim(sg_ax.get_xlim())
        try:
            fig.canvas.draw_idle()
        except Exception:
            # Best effort: some canvases cannot redraw; never let that break
            # the Plot button (previously a bare except that also swallowed
            # KeyboardInterrupt/SystemExit).
            pass

    def default_traits_view(self):
        v = View(
            HGroup(
                VGroup(
                    HGroup(
                        VGroup(
                            Label('Channel to plot'),
                            UItem('channel',
                                  editor=EnumEditor(name='object._chan_list'))),
                        Label('Log-Hz?'),
                        UItem('log_freq'),
                        UItem('plot')),
                    HGroup(Item('freq_lo', label='low', width=3),
                           Item('freq_hi', label='high', width=3),
                           label='Freq. Range')),
                VGroup(Item('high_res', label='High-res SG'),
                       Item('normalize', label='Normalize'),
                       Item('baseline', label='Baseline length (sec)'),
                       label='Spectrogram setup'),
                VGroup(Item('NW'),
                       Item('_bandwidth', label='BW (Hz)',
                            style='readonly', width=4),
                       Item('strip', label='SG strip len (ms)'),
                       Item('lag', label='SG lag (ms)'),
                       Item('detrend', label='Detrend window'),
                       Item('adaptive', label='Adaptive MTM'),
                       label='Estimation details', columns=2),
                VGroup(Item('over_samp', label='Oversamp high-res SG'),
                       enabled_when='high_res')))
        return v
class IonOpticsManager(Manager):
    """Position the spectrometer (magnet DAC / accelerating voltage) and run
    peak-center and coincidence scans.
    """
    reference_detector = Instance(BaseDetector)
    reference_isotope = Any
    magnet_dac = Range(0.0, 6.0)
    graph = Instance(Graph)
    peak_center_button = Button('Peak Center')
    stop_button = Button('Stop')

    alive = Bool(False)
    spectrometer = Any

    peak_center = Instance(PeakCenter)
    coincidence = Instance(Coincidence)

    peak_center_config = Instance(PeakCenterConfigurer)
    # NOTE(review): coincidence_config is not declared as a trait (the
    # declaration below is commented out); setup_coincidence reads
    # self.coincidence_config and _coincidence_config_default builds one --
    # confirm Traits actually honors that default for an undeclared name.
    # coincidence_config = Instance(CoincidenceConfig)
    canceled = False

    peak_center_result = None

    _centering_thread = None

    def close(self):
        """Cancel any running peak center."""
        self.cancel_peak_center()

    def cancel_peak_center(self):
        """Flag the current peak center as canceled and stop it.

        Safe to call before any peak center has been set up.
        """
        self.alive = False
        self.canceled = True
        pc = self.peak_center
        # peak_center is only created by setup_peak_center; guard so close()
        # works even if centering was never configured.
        if pc is not None:
            pc.canceled = True
            pc.stop()
        self.info('peak center canceled')

    def get_mass(self, isotope_key):
        """Return the molecular weight of ``isotope_key`` from the spectrometer."""
        spec = self.spectrometer
        molweights = spec.molecular_weights
        return molweights[isotope_key]

    def mftable_ctx(self, mftable):
        """Return an MFTableCTX wrapping this manager and ``mftable``."""
        return MFTableCTX(self, mftable)

    def set_mftable(self, name=None):
        """
        Select the magnet field table.

        Deflection correction is enabled when ``name`` is empty/None or equals
        the basename (without extension) of ``paths.mftable``; any other name
        disables it. ``name`` is then forwarded to ``magnet.set_mftable``.

        :param name: mftable basename, or None
        :return: None
        """
        if name and name != os.path.splitext(os.path.basename(
                paths.mftable))[0]:
            self.spectrometer.use_deflection_correction = False
        else:
            self.spectrometer.use_deflection_correction = True

        self.spectrometer.magnet.set_mftable(name)

    def get_position(self, *args, **kw):
        """Like ``_get_position`` but never updates the spectrometer isotopes."""
        kw['update_isotopes'] = False
        return self._get_position(*args, **kw)

    def av_position(self, pos, detector, *args, **kw):
        """Set the source HV so ``pos`` lands on ``detector``; return the HV."""
        av = self._get_av_position(pos, detector)
        self.spectrometer.source.set_hv(av)
        self.info('positioning {} ({}) on {}'.format(pos, av, detector))
        return av

    def position(self, pos, detector, use_af_demag=True, *args, **kw):
        """Set the magnet DAC so ``pos`` lands on ``detector``.

        Returns whatever ``magnet.set_dac`` returns.
        """
        dac = self._get_position(pos, detector, *args, **kw)
        mag = self.spectrometer.magnet

        self.info('positioning {} ({}) on {}'.format(pos, dac, detector))
        return mag.set_dac(dac, use_af_demag=use_af_demag)

    def do_coincidence_scan(self, new_thread=True):
        """Run the coincidence scan, in a background thread by default.

        :param new_thread: if False the scan runs synchronously in the caller.
        """
        if new_thread:
            t = Thread(name='ion_optics.coincidence', target=self._coincidence)
            t.start()
            self._centering_thread = t
        else:
            # previously new_thread=False silently did nothing
            self._coincidence()

    def setup_coincidence(self):
        """Ask the user for a coincidence configuration and build a Coincidence.

        :return: the new Coincidence, or None if the dialog was canceled.
        """
        pcc = self.coincidence_config
        pcc.dac = self.spectrometer.magnet.dac

        info = pcc.edit_traits()
        if not info.result:
            return

        detector = pcc.detector
        isotope = pcc.isotope
        detectors = list(pcc.additional_detectors)

        if pcc.use_nominal_dac:
            center_dac = self.get_position(isotope, detector)
        elif pcc.use_current_dac:
            center_dac = self.spectrometer.magnet.dac
        else:
            center_dac = pcc.dac

        cs = Coincidence(spectrometer=self.spectrometer,
                         center_dac=center_dac,
                         reference_detector=detector,
                         reference_isotope=isotope,
                         additional_detectors=detectors)
        self.coincidence = cs
        return cs

    def get_center_dac(self, det, iso):
        """Return the deflection-corrected DAC that puts ``iso`` on ``det``."""
        spec = self.spectrometer
        det = spec.get_detector(det)

        molweights = spec.molecular_weights
        mass = molweights[iso]
        dac = spec.magnet.map_mass_to_dac(mass, det.name)

        # correct for deflection
        return spec.correct_dac(det, dac)

    def do_peak_center(self,
                       save=True,
                       confirm_save=False,
                       warn=False,
                       new_thread=True,
                       message='',
                       on_end=None,
                       timeout=None):
        """Run the configured peak center.

        :return: the started Thread when ``new_thread`` is True, else None.
        """
        self.debug('doing pc')

        self.canceled = False
        self.alive = True
        self.peak_center_result = None

        args = (save, confirm_save, warn, message, on_end, timeout)
        if new_thread:
            t = Thread(name='ion_optics.peak_center',
                       target=self._peak_center,
                       args=args)
            t.start()
            self._centering_thread = t
            return t
        else:
            self._peak_center(*args)

    def setup_peak_center(self, detector=None, isotope=None,
                          integration_time=1.04,
                          directions='Increase',
                          center_dac=None, name='', show_label=False,
                          window=0.015, step_width=0.0005,
                          min_peak_height=1.0, percent=80,
                          deconvolve=None,
                          use_interpolation=False,
                          interpolation_kind='linear',
                          dac_offset=None, calculate_all_peaks=False,
                          config_name=None,
                          use_configuration_dac=True,
                          new=False,
                          update_others=True,
                          plot_panel=None):
        """Configure (and if needed create) the PeakCenter used by do_peak_center.

        Parameters come from the keyword arguments. They are overridden by a
        named configuration (``config_name``) or, when ``detector``/``isotope``
        are missing, by a configuration chosen interactively.

        :return: the configured peak center, or None if the user canceled.
        """
        # Defaults for a single-peak search. They are always bound so the
        # trait_set below cannot hit an unbound name. Previously a non-None
        # ``deconvolve`` without a configuration raised UnboundLocalError.
        # NOTE(review): ``deconvolve`` itself is not read anywhere in this
        # method.
        n_peaks, select_peak = 1, 1
        use_dac_offset = dac_offset is not None

        spec = self.spectrometer
        pcconfig = self.peak_center_config

        spec.save_integration()
        self.debug('setup peak center. detector={}, isotope={}'.format(
            detector, isotope))

        pcc = None
        dataspace = 'dac'
        use_accel_voltage = False
        use_extend = False
        self._setup_config()
        if config_name:
            pcconfig.load()
            pcconfig.active_name = config_name
            pcc = pcconfig.active_item

        elif detector is None or isotope is None:
            self.debug('ask user for peak center configuration')

            pcconfig.load()
            info = pcconfig.edit_traits()
            if not info.result:
                return
            pcc = pcconfig.active_item

        if pcc:
            if not detector:
                detector = pcc.active_detectors

            if not isotope:
                isotope = pcc.isotope

            directions = pcc.directions
            integration_time = pcc.integration_time

            dataspace = pcc.dataspace
            use_accel_voltage = pcc.use_accel_voltage
            use_extend = pcc.use_extend
            window = pcc.window
            min_peak_height = pcc.min_peak_height
            step_width = pcc.step_width
            percent = pcc.percent

            use_interpolation = pcc.use_interpolation
            interpolation_kind = pcc.interpolation_kind
            n_peaks = pcc.n_peaks
            select_peak = pcc.select_n_peak
            use_dac_offset = pcc.use_dac_offset
            dac_offset = pcc.dac_offset
            calculate_all_peaks = pcc.calculate_all_peaks
            update_others = pcc.update_others
            if not pcc.use_mftable_dac and center_dac is None and use_configuration_dac:
                center_dac = pcc.dac

        spec.set_integration_time(integration_time)
        period = int(integration_time * 1000 * 0.9)

        if not isinstance(detector, (tuple, list)):
            detector = (detector, )

        ref = spec.get_detector(detector[0])

        if center_dac is None:
            center_dac = self.get_center_dac(ref, isotope)

        ad = detector[1:] if len(detector) > 1 else []

        pc = self.peak_center
        klass = AccelVoltagePeakCenter if use_accel_voltage else PeakCenter
        # Recreate the peak center when none exists, when a new one is
        # requested, or when its kind (accel-voltage vs magnet) does not
        # match the requested dataspace. Previously an accel-voltage instance
        # was reused for a later magnet scan.
        if not pc or new or isinstance(pc, AccelVoltagePeakCenter) != bool(use_accel_voltage):
            pc = klass()

        pc.trait_set(center_dac=center_dac,
                     dataspace=dataspace,
                     use_accel_voltage=use_accel_voltage,
                     use_extend=use_extend,
                     period=period,
                     window=window,
                     percent=percent,
                     min_peak_height=min_peak_height,
                     step_width=step_width,
                     directions=directions,
                     reference_detector=ref,
                     additional_detectors=ad,
                     reference_isotope=isotope,
                     spectrometer=spec,
                     show_label=show_label,
                     use_interpolation=use_interpolation,
                     interpolation_kind=interpolation_kind,
                     n_peaks=n_peaks,
                     select_peak=select_peak,
                     use_dac_offset=use_dac_offset,
                     dac_offset=dac_offset,
                     calculate_all_peaks=calculate_all_peaks,
                     update_others=update_others)

        graph = pc.graph
        graph.name = name
        if plot_panel:
            plot_panel.set_peak_center_graph(graph)

        self.peak_center = pc
        self.reference_detector = ref
        self.reference_isotope = isotope

        return self.peak_center

    def backup_mftable(self):
        """Back up the magnet's field table."""
        self.spectrometer.magnet.field_table.backup()

    # private
    def _setup_config(self):
        """Populate the peak center configurer with detectors, isotopes and
        integration times from the spectrometer."""
        config = self.peak_center_config
        config.detectors = self.spectrometer.detector_names
        keys = list(self.spectrometer.molecular_weights.keys())
        config.isotopes = sort_isotopes(keys)
        config.integration_times = self.spectrometer.integration_times

    def _peak_center(self, save, confirm_save, warn, message, on_end, timeout):
        """Thread target for do_peak_center.

        ``timeout`` is accepted for interface compatibility but is unused.
        Cleanup runs in ``finally``: call ``on_end``, set ``alive`` False,
        restore the integration time. It therefore also happens when
        centering raises, so the manager is never left stuck in the alive
        state.
        """
        try:
            self._run_peak_center(save, confirm_save, warn, message)
        finally:
            # alive=False enables IonOptics>Peak Center in the menubar
            if on_end:
                on_end()
            self.trait_set(alive=False)
            self.spectrometer.restore_integration()

    def _run_peak_center(self, save, confirm_save, warn, message):
        """Run the centering scan and, on success, optionally update the
        field table (magnet or accel-voltage) with the new center."""
        pc = self.peak_center
        spec = self.spectrometer
        ref = self.reference_detector
        isotope = self.reference_isotope

        try:
            center_value = pc.get_peak_center()
        except NoIntensityChange as e:
            self.warning(
                'Peak Centering failed. No Intensity change. {}'.format(e))
            center_value = None

        self.peak_center_result = center_value
        if not center_value:
            if not self.canceled:
                msg = 'centering failed'
                if warn:
                    self.warning_dialog(msg)
                self.warning(msg)
            return

        det = spec.get_detector(ref)

        if pc.use_accel_voltage:
            args = ref, isotope, center_value
        else:
            dac_a = spec.uncorrect_dac(det, center_value)
            self.info(
                'dac uncorrected for HV and deflection {}'.format(dac_a))
            args = ref, isotope, dac_a
            self.adjusted_peak_center_result = dac_a

        self.info('new center pos {} ({}) @ {}'.format(*args))
        if save and confirm_save:
            if pc.use_accel_voltage:
                msg = 'Update Accel Voltage Table with new peak center- {} ({}) @ RefDetUnits= {}'.format(
                    *args)
            else:
                msg = 'Update Magnet Field Table with new peak center- {} ({}) @ RefDetUnits= {}'.format(
                    *args)
            save = self.confirmation_dialog(msg)

        if save:
            if pc.use_accel_voltage:
                spec.source.update_field_table(det, isotope, center_value,
                                               message)
            else:
                spec.magnet.update_field_table(
                    det, isotope, dac_a, message,
                    update_others=pc.update_others)
                spec.magnet.set_dac(dac_a)

    def _get_av_position(self, pos, detector, update_isotopes=True):
        """Return the source HV that puts ``pos`` (isotope name or mass) on
        ``detector``."""
        self.debug('AV POSITION {} {}'.format(pos, detector))
        spec = self.spectrometer
        if not isinstance(detector, str):
            detector = detector.name

        if isinstance(pos, str):
            try:
                pos = float(pos)
            except ValueError:
                # pos is isotope
                if update_isotopes:
                    # if the pos is an isotope then update the detectors
                    spec.update_isotopes(pos, detector)
                pos = self.get_mass(pos)

        # pos is mass i.e 39.962
        av = spec.source.map_mass_to_hv(pos, detector)
        return av

    def _get_position(self, pos, detector, use_dac=False,
                      update_isotopes=True):
        """
            pos can be str or float
            "Ar40", "39.962", 39.962

            to set in DAC space set use_dac=True

            Returns None when ``pos`` is NULL_STR.
        """
        if pos == NULL_STR:
            return

        spec = self.spectrometer
        mag = spec.magnet

        if isinstance(detector, str):
            det = spec.get_detector(detector)
        else:
            det = detector

        self.debug('detector {}'.format(det))

        if use_dac:
            dac = pos
        else:
            self.debug('POSITION {} {}'.format(pos, detector))
            if isinstance(pos, str):
                try:
                    pos = float(pos)
                except ValueError:
                    # pos is isotope
                    if update_isotopes:
                        # if the pos is an isotope then update the detectors
                        spec.update_isotopes(pos, detector)
                    pos = self.get_mass(pos)

                mag.mass_change(pos)

            # pos is mass i.e 39.962
            dac = mag.map_mass_to_dac(pos, det.name)

        dac = spec.correct_dac(det, dac)
        return dac

    def _coincidence(self):
        """Run the coincidence scan, then restore the integration time."""
        self.coincidence.get_peak_center()
        self.info('coincidence finished')
        self.spectrometer.restore_integration()

    # ===============================================================================
    # handler
    # ===============================================================================
    def _coincidence_config_default(self):
        """Load the pickled coincidence config from the hidden dir, falling
        back to a fresh CoincidenceConfig."""
        config = None
        p = os.path.join(paths.hidden_dir, 'coincidence_config.p')
        if os.path.isfile(p):
            try:
                # pickle requires binary mode; text mode made this load
                # always fail on Python 3. The file is a local hidden file
                # written by this application (never untrusted input).
                with open(p, 'rb') as rfile:
                    config = pickle.load(rfile)
                    config.detectors = dets = self.spectrometer.detectors
                    config.detector = next(
                        (di for di in dets
                         if di.name == config.detector_name), None)

            except Exception as e:
                self.warning('coincidence config {}'.format(e))

        if config is None:
            config = CoincidenceConfig()
            config.detectors = self.spectrometer.detectors
            config.detector = config.detectors[0]

        keys = list(self.spectrometer.molecular_weights.keys())
        config.isotopes = sort_isotopes(keys)

        return config

    def _peak_center_config_default(self):
        """Default peak center configurer."""
        config = PeakCenterConfigurer()
        return config
class GridCell(SDomain):
    '''
    A single mgrid cell for geometrical representation of the domain.

    Based on the grid_cell_spec attribute,
    the node distribution is determined.
    '''
    # Everything depends on the grid_cell_specification
    #
    grid_cell_spec = Instance(CellSpec)

    def _grid_cell_spec_default(self):
        return CellSpec()

    # Generated grid cell coordinates as they come from mgrid.
    # The dimensionality of the mgrid comes from the
    # grid_cell_spec_attribute
    #
    grid_cell_coords = Property(depends_on='grid_cell_spec')

    @cached_property
    def _get_grid_cell_coords(self):
        '''Return one row of coordinates per mgrid point of the cell.'''
        grid_cell = mgrid[self.grid_cell_spec.get_cell_slices()]
        return c_[tuple([x.flatten() for x in grid_cell])]

    n_nodes = Property(depends_on='grid_cell_spec')

    @cached_property
    def _get_n_nodes(self):
        '''Return the number of all nodes within the cell.
        '''
        return self.grid_cell_coords.shape[0]

    # Node map lists the active nodes within the grid cell
    # in the specified order
    #
    node_map = Property(Array(int), depends_on='grid_cell_spec')

    @cached_property
    def _get_node_map(self):
        '''For each node of the cell spec, return the index of the first
        grid cell coordinate that coincides with it (atol 1e-3).
        '''
        n_map = []
        for node in self.grid_cell_spec._node_array:
            for idx, grid_cell_node in enumerate(self.grid_cell_coords):
                if allclose(node, grid_cell_node, atol=1.0e-3):
                    n_map.append(idx)
                    # stop at the first match; the former `continue` was a
                    # no-op and could record several indices for one node
                    break
        return array(n_map, int)

    # NOTE: the visualization trait `mvp_mgrid_ngeo_labels` (an MVPointLabels
    # built from _get_points / _get_node_distribution) is currently not
    # declared on this class.
    refresh_button = Button('Draw')

    @on_trait_change('refresh_button')
    def redraw(self):
        '''Redraw the node-number labels, if a label view is attached.
        '''
        # Guard: the labels trait is not declared, so unconditional access
        # raised AttributeError whenever "Draw" was pressed.
        labels = getattr(self, 'mvp_mgrid_ngeo_labels', None)
        if labels is not None:
            labels.redraw('label_scalars')

    def _get_points(self):
        '''Return the mapped node coordinates, padded with zeros to 3D.'''
        points = self.grid_cell_coords[ix_(self.node_map)]
        shape = points.shape
        if shape[1] < 3:
            _points = zeros((shape[0], 3), dtype=float)
            _points[:, 0:shape[1]] = points
            return _points
        else:
            return points

    def _get_node_distribution(self):
        '''Return, per grid cell coordinate, its position in node_map,
        or -1 for coordinates that are not active nodes.'''
        n_points = self.grid_cell_coords.shape[0]
        full_node_map = ones(n_points, dtype=float) * -1.
        full_node_map[ix_(self.node_map)] = arange(len(self.node_map))
        return full_node_map

    def __getitem__(self, idx):
        '''Return a boolean mask over node_map selecting the nodes that fall
        into the slice ``idx`` of the cell-shaped grid.'''
        # construct the full boolean map of the grid cell'
        node_bool_map = repeat(False, self.n_nodes).reshape(
            self.grid_cell_spec.get_cell_shape())
        # put true at the sliced positions
        node_bool_map[idx] = True
        # extract the used nodes using the node map
        node_selection = node_bool_map.flatten()[self.node_map]
        return node_selection

    #------------------------------------------------------------------
    # UI - related methods
    #------------------------------------------------------------------
    traits_view = View(Item('grid_cell_spec'),
                       Item('refresh_button'),
                       Item('node_map'),
                       resizable=True,
                       height=0.5,
                       width=0.5)