class AdaptationOffer(HasTraits):
    """ An offer to provide adapters from one protocol to another.

    An adaptation offer consists of a factory that can create adapters, and the
    protocols that define what the adapters adapt from and to.

    The factory and both protocols may be given lazily as dotted-path strings
    ('foo.bar.baz'); they are imported via ``import_symbol`` the first time the
    corresponding property is read.

    """

    #### 'object' protocol ####################################################

    def __repr__(self):
        """ Return a string representation of the object. """

        template = "<AdaptationOffer: '{from_}' -> '{to}'>"

        from_ = self.from_protocol_name
        to = self.to_protocol_name

        return template.format(from_=from_, to=to)

    #### 'AdaptationOffer' protocol ###########################################

    #: A factory for creating adapters.
    #:
    #: The factory must be a callable that takes exactly one argument which is
    #: the object to be adapted (known as the adaptee), and returns an
    #: adapter from the `from_protocol` to the `to_protocol`.
    #:
    #: The factory can be specified as either a callable, or a string in the
    #: form 'foo.bar.baz' which is turned into an import statement
    #: 'from foo.bar import baz' and imported when the trait is first accessed.
    factory = Property(Any)

    #: Adapters created by the factory adapt *from* this protocol.
    #:
    #: The protocol can be specified as a protocol (class/Interface), or a
    #: string in the form 'foo.bar.baz' which is turned into an import
    #: statement 'from foo.bar import baz' and imported when the trait is
    #: accessed.
    from_protocol = Property(Any)

    #: Dotted name of `from_protocol` (computed without triggering an import).
    from_protocol_name = Property(Any)

    def _get_from_protocol_name(self):
        # Read the shadow trait so that asking for the name never forces a
        # lazy import.
        return self._get_type_name(self._from_protocol)

    #: Adapters created by the factory adapt *to* this protocol.
    #:
    #: The protocol can be specified as a protocol (class/Interface), or a
    #: string in the form 'foo.bar.baz' which is turned into an import
    #: statement 'from foo.bar import baz' and imported when the trait is
    #: accessed.
    to_protocol = Property(Any)

    #: Dotted name of `to_protocol` (computed without triggering an import).
    to_protocol_name = Property(Any)

    def _get_to_protocol_name(self):
        # Read the shadow trait so that asking for the name never forces a
        # lazy import.
        return self._get_type_name(self._to_protocol)

    #### Private protocol ######################################################

    #: Shadow trait for the corresponding property.
    _factory = Any
    _factory_loaded = Bool(False)

    def _get_factory(self):
        """ Trait property getter. """

        if not self._factory_loaded:
            if isinstance(self._factory, str):
                self._factory = import_symbol(self._factory)

            self._factory_loaded = True

        return self._factory

    def _set_factory(self, factory):
        """ Trait property setter. """

        self._factory = factory
        # The new value may itself be a dotted-path string; clear the flag so
        # the next read imports it instead of returning the raw string.
        self._factory_loaded = False

        return

    #: Shadow trait for the corresponding property.
    _from_protocol = Any
    _from_protocol_loaded = Bool(False)

    def _get_from_protocol(self):
        """ Trait property getter. """

        if not self._from_protocol_loaded:
            if isinstance(self._from_protocol, str):
                self._from_protocol = import_symbol(self._from_protocol)

            self._from_protocol_loaded = True

        return self._from_protocol

    def _set_from_protocol(self, from_protocol):
        """ Trait property setter. """

        self._from_protocol = from_protocol
        # Re-arm lazy loading for a newly assigned (possibly string) value.
        self._from_protocol_loaded = False

        return

    #: Shadow trait for the corresponding property.
    _to_protocol = Any
    _to_protocol_loaded = Bool(False)

    def _get_to_protocol(self):
        """ Trait property getter. """

        if not self._to_protocol_loaded:
            if isinstance(self._to_protocol, str):
                self._to_protocol = import_symbol(self._to_protocol)

            self._to_protocol_loaded = True

        return self._to_protocol

    def _set_to_protocol(self, to_protocol):
        """ Trait property setter. """

        self._to_protocol = to_protocol
        # Re-arm lazy loading for a newly assigned (possibly string) value.
        self._to_protocol_loaded = False

        return

    def _get_type_name(self, type_or_type_name):
        """ Returns the full dotted path for a type.

        For example:
        from traits.api import HasTraits
        _get_type_name(HasTraits) == 'traits.has_traits.HasTraits'

        If the type is given as a string (e.g., for lazy loading), it is just
        returned. If no type has been set (None), None is returned so that
        ``repr`` of an incompletely configured offer does not raise.

        """

        if type_or_type_name is None:
            return None

        if isinstance(type_or_type_name, str):
            type_name = type_or_type_name

        else:
            type_name = "{module}.{name}".format(
                module=type_or_type_name.__module__,
                name=type_or_type_name.__name__)

        return type_name
class TableEditor(Editor, BaseTableEditor):
    """ Editor that presents data in a table. Optionally, tables can have
        a set of filters that reduce the set of data displayed, according to
        their criteria.
    """

    #-------------------------------------------------------------------------
    # Trait definitions:
    #-------------------------------------------------------------------------

    # The table view control associated with the editor:
    table_view = Any

    def _table_view_default(self):
        return TableView(editor=self)

    # A wrapper around the source model which provides filtering and sorting:
    model = Instance(SortFilterTableModel)

    def _model_default(self):
        return SortFilterTableModel(editor=self)

    # The table model associated with the editor:
    source_model = Instance(TableModel)

    def _source_model_default(self):
        return TableModel(editor=self)

    # The set of columns currently defined on the editor:
    columns = List(TableColumn)

    # The currently selected row(s), column(s), or cell(s).
    selected = Any

    # The current selected row
    selected_row = Property(Any, depends_on='selected')

    # The (row, column) source-model indices of the current selection
    selected_indices = Property(Any, depends_on='selected')

    # Current filter object (should be a TableFilter or callable or None):
    filter = Any

    # The indices of the table items currently passing the table filter:
    filtered_indices = List(Int)

    # Current filter summary message
    filter_summary = Str('All items')

    # Update the filtered contents.
    update_filter = Event()

    # The event fired when a cell is clicked on:
    click = Event

    # The event fired when a cell is double-clicked on:
    dclick = Event

    # The Traits UI associated with the table editor toolbar:
    toolbar_ui = Instance(UI)

    # The context menu associated with empty space in the table
    empty_menu = Instance(QtGui.QMenu)

    # The context menu associated with the vertical header
    header_menu = Instance(QtGui.QMenu)

    # The context menu actions for moving rows up and down
    header_menu_up = Instance(QtGui.QAction)
    header_menu_down = Instance(QtGui.QAction)

    # The index of the row that was last right clicked on its vertical header
    header_row = Int

    # Whether to auto-size the columns or not.
    auto_size = Bool(False)

    # Dictionary mapping image names to QIcons
    images = Any({})

    # Dictionary mapping ImageResource objects to QIcons
    image_resources = Any({})

    # An image being converted:
    image = Image

    #-------------------------------------------------------------------------
    # Finishes initializing the editor by creating the underlying toolkit
    # widget:
    #-------------------------------------------------------------------------

    def init(self, parent):
        """Finishes initializing the editor by creating the underlying toolkit
        widget."""

        factory = self.factory
        self.columns = factory.columns[:]
        if factory.table_view_factory is not None:
            self.table_view = factory.table_view_factory(editor=self)
        if factory.source_model_factory is not None:
            self.source_model = factory.source_model_factory(editor=self)
        if factory.model_factory is not None:
            self.model = factory.model_factory(editor=self)

        # Create the table view and model
        self.model.setDynamicSortFilter(True)
        self.model.setSourceModel(self.source_model)
        self.table_view.setModel(self.model)

        # Create the vertical header context menu and connect to its signals
        self.header_menu = QtGui.QMenu(self.table_view)
        signal = QtCore.SIGNAL('triggered()')
        insertable = factory.row_factory is not None and not factory.auto_add
        if factory.editable:
            if insertable:
                action = self.header_menu.addAction('Insert new item')
                QtCore.QObject.connect(action, signal, self._on_context_insert)
            if factory.deletable:
                action = self.header_menu.addAction('Delete item')
                QtCore.QObject.connect(action, signal, self._on_context_remove)
        if factory.reorderable:
            # Only separate the move actions if edit actions precede them.
            if factory.editable and (insertable or factory.deletable):
                self.header_menu.addSeparator()
            self.header_menu_up = self.header_menu.addAction('Move item up')
            QtCore.QObject.connect(self.header_menu_up, signal,
                                   self._on_context_move_up)
            self.header_menu_down = self.header_menu.addAction(
                'Move item down')
            QtCore.QObject.connect(self.header_menu_down, signal,
                                   self._on_context_move_down)

        # Create the empty space context menu and connect its signals
        self.empty_menu = QtGui.QMenu(self.table_view)
        action = self.empty_menu.addAction('Add new item')
        QtCore.QObject.connect(action, signal, self._on_context_append)

        # When sorting is enabled, the first column is initially displayed with
        # the triangle indicating it is the sort index, even though no sorting
        # has actually been done. Sort here for UI/model consistency.
        if self.factory.sortable and not self.factory.reorderable:
            self.model.sort(0, QtCore.Qt.AscendingOrder)

        # Connect to the mode specific selection handler and select the first
        # row/column/cell. Do this before creating the edit_view to make sure
        # that it has a valid item to use when constructing its view.
        smodel = self.table_view.selectionModel()
        signal = QtCore.SIGNAL(
            'selectionChanged(QItemSelection, QItemSelection)')
        mode_slot = getattr(self, '_on_%s_selection' % factory.selection_mode)
        QtCore.QObject.connect(smodel, signal, mode_slot)
        self.table_view.setCurrentIndex(self.model.index(0, 0))

        # Create the toolbar if necessary
        if factory.show_toolbar and len(factory.filters) > 0:
            main_view = QtGui.QWidget()
            layout = QtGui.QVBoxLayout(main_view)
            layout.setContentsMargins(0, 0, 0, 0)
            self.toolbar_ui = self.edit_traits(
                parent=parent,
                kind='subpanel',
                view=View(Group(Item('filter{View}',
                                     editor=factory._filter_editor),
                                Item('filter_summary{Results}',
                                     style='readonly'),
                                spring,
                                orientation='horizontal'),
                          resizable=True))
            self.toolbar_ui.parent = self.ui
            layout.addWidget(self.toolbar_ui.control)
            layout.addWidget(self.table_view)
        else:
            main_view = self.table_view

        # Create auxiliary editor and encompassing splitter if necessary
        mode = factory.selection_mode
        if (factory.edit_view == ' ') or mode not in ('row', 'rows'):
            self.control = main_view
        else:
            if factory.orientation == 'horizontal':
                self.control = QtGui.QSplitter(QtCore.Qt.Horizontal)
            else:
                self.control = QtGui.QSplitter(QtCore.Qt.Vertical)
            self.control.setSizePolicy(QtGui.QSizePolicy.Expanding,
                                       QtGui.QSizePolicy.Expanding)
            self.control.addWidget(main_view)
            self.control.setStretchFactor(0, 2)

            # Create the row editor below the table view
            editor = InstanceEditor(view=factory.edit_view, kind='subpanel')
            self._ui = self.edit_traits(
                parent=self.control,
                kind='subpanel',
                view=View(Item('selected_row',
                               style='custom',
                               editor=editor,
                               show_label=False,
                               resizable=True,
                               width=factory.edit_view_width,
                               height=factory.edit_view_height),
                          resizable=True,
                          handler=factory.edit_view_handler))
            self._ui.parent = self.ui
            self.control.addWidget(self._ui.control)
            self.control.setStretchFactor(1, 1)

        # Connect to the click and double click handlers
        signal = QtCore.SIGNAL('clicked(QModelIndex)')
        QtCore.QObject.connect(self.table_view, signal, self._on_click)
        signal = QtCore.SIGNAL('doubleClicked(QModelIndex)')
        QtCore.QObject.connect(self.table_view, signal, self._on_dclick)

        # Make sure we listen for 'items' changes as well as complete list
        # replacements
        self.context_object.on_trait_change(
            self.update_editor, self.extended_name + '_items', dispatch='ui')

        # Listen for changes to traits on the objects in the list
        self.context_object.on_trait_change(
            self.refresh_editor, self.extended_name + '.-', dispatch='ui')

        # Listen for changes on column definitions
        self.on_trait_change(self._update_columns, 'columns', dispatch='ui')
        self.on_trait_change(self._update_columns, 'columns_items',
                             dispatch='ui')

        # Set up the required externally synchronized traits
        is_list = (mode in ('rows', 'columns', 'cells'))
        self.sync_value(factory.click, 'click', 'to')
        self.sync_value(factory.dclick, 'dclick', 'to')
        self.sync_value(factory.columns_name, 'columns', is_list=True)
        self.sync_value(factory.selected, 'selected', is_list=is_list)
        self.sync_value(
            factory.selected_indices, 'selected_indices', is_list=is_list)
        self.sync_value(factory.filter_name, 'filter', 'from')
        self.sync_value(factory.filtered_indices, 'filtered_indices', 'to')
        self.sync_value(factory.update_filter_name, 'update_filter', 'from')

        self.auto_size = self.factory.auto_size

        # Initialize the ItemDelegates for each column
        self._update_columns()

    #-------------------------------------------------------------------------
    # Disposes of the contents of an editor:
    #-------------------------------------------------------------------------

    def dispose(self):
        """ Disposes of the contents of an editor."""

        # Make sure that the auxiliary UIs are properly disposed
        if self.toolbar_ui is not None:
            self.toolbar_ui.dispose()
        if self._ui is not None:
            self._ui.dispose()

        # Remove listener for 'items' changes on object trait
        self.context_object.on_trait_change(
            self.update_editor, self.extended_name + '_items', remove=True)

        # Remove listener for changes to traits on the objects in the list
        self.context_object.on_trait_change(
            self.refresh_editor, self.extended_name + '.-', remove=True)

        # Remove listeners for column definition changes
        self.on_trait_change(self._update_columns, 'columns', remove=True)
        self.on_trait_change(
            self._update_columns, 'columns_items', remove=True)

        super(TableEditor, self).dispose()

    #-------------------------------------------------------------------------
    # Updates the editor when the object trait changes external to the editor:
    #-------------------------------------------------------------------------

    def update_editor(self):
        """Updates the editor when the object trait changes externally to the
        editor."""

        if self._no_notify:
            return

        self.table_view.setUpdatesEnabled(False)
        try:
            filtering = len(
                self.factory.filters) > 0 or self.filter is not None
            if filtering:
                self._update_filtering()

            # invalidate the model, but do not reset it. Resetting the model
            # may cause problems if the selection sync'ed traits are being used
            # externally to manage the selections
            self.model.invalidate()

            self.table_view.resizeColumnsToContents()
            if self.auto_size:
                self.table_view.resizeRowsToContents()

        finally:
            self.table_view.setUpdatesEnabled(True)

    def restore_prefs(self, prefs):
        """ Restores any saved user preference information associated with the
            editor.
        """
        header = self.table_view.horizontalHeader()
        if header is not None and 'column_state' in prefs:
            header.restoreState(prefs['column_state'])

    def save_prefs(self):
        """ Returns any user preference information associated with the editor.
        """
        prefs = {}
        header = self.table_view.horizontalHeader()
        if header is not None:
            # Store the raw bytes of the QByteArray. ``str()`` of a QByteArray
            # under Python 3 yields its repr, which ``restoreState`` cannot
            # parse; ``.data()`` returns the actual state bytes.
            prefs['column_state'] = header.saveState().data()
        return prefs

    #-------------------------------------------------------------------------
    # Requests that the underlying table widget redraw itself:
    #-------------------------------------------------------------------------

    def refresh_editor(self):
        """Requests that the underlying table widget redraw itself."""

        self.table_view.viewport().update()

    #-------------------------------------------------------------------------
    # Creates a new row object using the provided factory:
    #-------------------------------------------------------------------------

    def create_new_row(self):
        """Creates a new row object using the provided factory."""

        factory = self.factory
        kw = factory.row_factory_kw.copy()
        # '__table_editor__' acts as a placeholder key: if present, the editor
        # itself is passed to the row factory under that keyword.
        if '__table_editor__' in kw:
            kw['__table_editor__'] = self

        return self.ui.evaluate(factory.row_factory,
                                *factory.row_factory_args, **kw)

    #-------------------------------------------------------------------------
    # Returns the raw list of model objects:
    #-------------------------------------------------------------------------

    def items(self):
        """Returns the raw list of model objects."""

        items = self.value
        if not isinstance(items, SequenceTypes):
            items = [items]

        if self.factory and self.factory.reverse:
            items = ReversedList(items)

        return items

    #-------------------------------------------------------------------------
    # Perform actions without notifying the underlying table view or model:
    #-------------------------------------------------------------------------

    def callx(self, func, *args, **kw):
        """Call a function without notifying the underlying table view or
        model."""

        old = self._no_notify
        self._no_notify = True
        try:
            func(*args, **kw)
        finally:
            self._no_notify = old

    def setx(self, **keywords):
        """Set one or more attributes without notifying the underlying table
        view or model."""

        old = self._no_notify
        self._no_notify = True
        try:
            for name, value in keywords.items():
                setattr(self, name, value)
        finally:
            self._no_notify = old

    #-------------------------------------------------------------------------
    # Sets the current selection to a set of specified objects:
    #-------------------------------------------------------------------------

    def set_selection(self, objects=[], notify=True):
        """Sets the current selection to a set of specified objects."""

        # NOTE: the list default is never mutated here, so sharing it is safe.
        if not isinstance(objects, SequenceTypes):
            objects = [objects]

        mode = self.factory.selection_mode
        indexes = []
        flags = QtGui.QItemSelectionModel.ClearAndSelect

        # In the case of row or column selection, we need a dummy value for the
        # other dimension that has not been filtered.
        source_index = self.model.mapToSource(self.model.index(0, 0))
        source_row, source_column = source_index.row(), source_index.column()

        # Selection mode is 'row' or 'rows'
        if mode.startswith('row'):
            flags |= QtGui.QItemSelectionModel.Rows
            items = self.items()
            for obj in objects:
                try:
                    row = items.index(obj)
                except ValueError:
                    continue
                indexes.append(self.source_model.index(row, source_column))

        # Selection mode is 'column' or 'columns'
        elif mode.startswith('column'):
            flags |= QtGui.QItemSelectionModel.Columns
            for name in objects:
                column = self._column_index_from_name(name)
                if column != -1:
                    indexes.append(self.source_model.index(source_row, column))

        # Selection mode is 'cell' or 'cells'
        else:
            items = self.items()
            for obj, name in objects:
                try:
                    row = items.index(obj)
                except ValueError:
                    continue
                column = self._column_index_from_name(name)
                if column != -1:
                    indexes.append(self.source_model.index(row, column))

        # Perform the selection so that only one signal is emitted
        selection = QtGui.QItemSelection()
        for index in indexes:
            index = self.model.mapFromSource(index)
            # Indexes hidden by the filter map to an invalid proxy index.
            if index.isValid():
                self.table_view.setCurrentIndex(index)
                selection.select(index, index)

        smodel = self.table_view.selectionModel()
        try:
            smodel.blockSignals(not notify)
            if len(selection.indexes()):
                smodel.select(selection, flags)
            else:
                smodel.clear()
        finally:
            smodel.blockSignals(False)

    #-------------------------------------------------------------------------
    # Private methods:
    #-------------------------------------------------------------------------

    def _column_index_from_name(self, name):
        """Returns the index of the column with the given name or -1 if no
        column exists with that name."""

        for i, column in enumerate(self.columns):
            if name == column.name:
                return i
        return -1

    def _customize_filters(self, filter):
        """Allows the user to customize the current set of table filters."""

        filter_editor = TableFilterEditor(editor=self)
        ui = filter_editor.edit_traits(parent=self.control)
        if ui.result:
            self.factory.filters = filter_editor.templates
            self.filter = filter_editor.selected_filter
        else:
            # Restore the previous filter without re-triggering filtering.
            self.setx(filter=filter)

    def _update_filtering(self):
        """Update the filter summary and the filtered indices."""

        items = self.items()
        num_items = len(items)

        f = self.filter
        if f is None:
            self._filtered_cache = None
            # ``filtered_indices`` is a List(Int) trait: pass a real list, not
            # a ``range`` object.
            self.filtered_indices = list(range(num_items))
            self.filter_summary = 'All %i items' % num_items
        else:
            if not callable(f):
                f = f.filter
            self._filtered_cache = fc = [f(item) for item in items]
            self.filtered_indices = fi = [i for i, ok in enumerate(fc) if ok]
            self.filter_summary = '%i of %i items' % (len(fi), num_items)

    def _add_image(self, image_resource):
        """ Adds a new image to the image map.
        """
        image = image_resource.create_icon()

        self.image_resources[image_resource] = image
        self.images[image_resource.name] = image

        return image

    def _get_image(self, image):
        """ Converts a user specified image to a QIcon.
        """
        # A string is run through the ``image`` trait, which converts it.
        if isinstance(image, basestring):
            self.image = image
            image = self.image

        if isinstance(image, ImageResource):
            result = self.image_resources.get(image)
            if result is not None:
                return result
            return self._add_image(image)

        return self.images.get(image)

    #-- Trait Property getters/setters ---------------------------------------

    @cached_property
    def _get_selected_row(self):
        """Gets the selected row, or the first row if multiple rows are
        selected."""

        mode = self.factory.selection_mode

        if mode.startswith('column'):
            return None
        elif mode == 'row':
            return self.selected

        try:
            if mode == 'rows':
                return self.selected[0]
            elif mode == 'cell':
                # ``selected`` is an (object, column_name) tuple.
                return self.selected[0]
            elif mode == 'cells':
                # ``selected`` is a list of (object, column_name) tuples.
                return self.selected[0][0]
        except IndexError:
            return None

    @cached_property
    def _get_selected_indices(self):
        """Gets the row,column indices which match the selected trait"""
        selection_items = self.table_view.selectionModel().selection()
        indices = self.model.mapSelectionToSource(selection_items).indexes()
        return [(index.row(), index.column()) for index in indices]

    def _set_selected_indices(self, indices):
        # Source-model rows index into ``self.items()`` (as in every selection
        # handler below), which honours ``factory.reverse``; ``self.value``
        # does not.
        items = self.items()
        selected = [(items[row], self.columns[col].name)
                    for row, col in indices]

        self.selected = selected
        self.set_selection(self.selected, False)
        return

    #-- Trait Change Handlers ------------------------------------------------

    def _filter_changed(self, old_filter, new_filter):
        """Handles the current filter being changed."""

        if not self._no_notify:
            if new_filter is customize_filter:
                do_later(self._customize_filters, old_filter)
            else:
                self._update_filtering()
                self.model.invalidate()
                self.set_selection(self.selected)

    def _update_columns(self):
        """Handle the column list being changed."""

        self.table_view.setItemDelegate(TableDelegate(self.table_view))
        for i, column in enumerate(self.columns):
            if column.renderer:
                self.table_view.setItemDelegateForColumn(i, column.renderer)

        self.model.reset()
        self.table_view.resizeColumnsToContents()
        if self.auto_size:
            self.table_view.resizeRowsToContents()

    def _selected_changed(self, new):
        """Handle the selected row/column/cell being changed externally."""
        if not self._no_notify:
            self.set_selection(self.selected, notify=False)

    def _update_filter_changed(self):
        """ The filter has changed internally.
        """
        self._filter_changed(self.filter, self.filter)

    #-- Event Handlers -------------------------------------------------------

    def _on_row_selection(self, added, removed):
        """Handle the row selection being changed."""

        items = self.items()
        indexes = self.table_view.selectionModel().selectedRows()
        if len(indexes):
            index = self.model.mapToSource(indexes[0])
            selected = items[index.row()]
        else:
            selected = None

        self.setx(selected=selected)
        self.ui.evaluate(self.factory.on_select, self.selected)

    def _on_rows_selection(self, added, removed):
        """Handle the rows selection being changed."""

        items = self.items()
        indexes = self.table_view.selectionModel().selectedRows()
        selected = [items[self.model.mapToSource(index).row()]
                    for index in indexes]

        self.setx(selected=selected)
        self.ui.evaluate(self.factory.on_select, self.selected)

    def _on_column_selection(self, added, removed):
        """Handle the column selection being changed."""

        indexes = self.table_view.selectionModel().selectedColumns()
        if len(indexes):
            index = self.model.mapToSource(indexes[0])
            selected = self.columns[index.column()].name
        else:
            selected = ''

        self.setx(selected=selected)
        self.ui.evaluate(self.factory.on_select, self.selected)

    def _on_columns_selection(self, added, removed):
        """Handle the columns selection being changed."""

        indexes = self.table_view.selectionModel().selectedColumns()
        selected = [self.columns[self.model.mapToSource(index).column()].name
                    for index in indexes]

        self.setx(selected=selected)
        self.ui.evaluate(self.factory.on_select, self.selected)

    def _on_cell_selection(self, added, removed):
        """Handle the cell selection being changed."""

        items = self.items()
        indexes = self.table_view.selectionModel().selectedIndexes()
        if len(indexes):
            index = self.model.mapToSource(indexes[0])
            obj = items[index.row()]
            column_name = self.columns[index.column()].name
        else:
            obj = None
            column_name = ''
        selected = (obj, column_name)

        self.setx(selected=selected)
        self.ui.evaluate(self.factory.on_select, self.selected)

    def _on_cells_selection(self, added, removed):
        """Handle the cells selection being changed."""

        items = self.items()
        indexes = self.table_view.selectionModel().selectedIndexes()
        selected = []
        for index in indexes:
            index = self.model.mapToSource(index)
            obj = items[index.row()]
            column_name = self.columns[index.column()].name
            selected.append((obj, column_name))

        self.setx(selected=selected)
        self.ui.evaluate(self.factory.on_select, self.selected)

    def _on_click(self, index):
        """Handle a cell being clicked."""

        index = self.model.mapToSource(index)
        column = self.columns[index.column()]
        obj = self.items()[index.row()]

        # Fire the same event on the editor after mapping it to a model object
        # and column object:
        self.click = (obj, column)

        # Invoke the column's click handler:
        column.on_click(obj)

    def _on_dclick(self, index):
        """Handle a cell being double clicked."""

        index = self.model.mapToSource(index)
        column = self.columns[index.column()]
        obj = self.items()[index.row()]

        # Fire the same event on the editor after mapping it to a model object
        # and column object:
        self.dclick = (obj, column)

        # Invoke the column's double-click handler:
        column.on_dclick(obj)

    def _on_context_insert(self):
        """Handle 'insert item' being selected from the header context menu."""

        self.model.insertRow(self.header_row)

    def _on_context_append(self):
        """Handle 'add item' being selected from the empty space context
        menu."""

        self.model.insertRow(self.model.rowCount())

    def _on_context_remove(self):
        """Handle 'remove item' being selected from the header context menu."""

        self.model.removeRow(self.header_row)

    def _on_context_move_up(self):
        """Handle 'move up' being selected from the header context menu."""

        self.model.moveRow(self.header_row, self.header_row - 1)

    def _on_context_move_down(self):
        """Handle 'move down' being selected from the header context menu."""

        self.model.moveRow(self.header_row, self.header_row + 1)
class InstSource(HasPrivateTraits):
    """Expose measurement information from an inst file.

    Parameters
    ----------
    file : File
        Path to a FIFF measurement file (*.fif) whose info contains
        digitizer points.

    Attributes
    ----------
    points : Array, shape = (n, 3)
        Extra (head shape) digitizer points, optionally subset by
        ``points_filter``.
    lpa, nasion, rpa : Array, shape = (1, 3)
        Coordinates of the LPA, nasion and RPA fiducial points. A fiducial is
        all zeros if no file is set or it is absent from the digitization.
    """

    file = File(exists=True, filter=['*.fif'])

    inst_fname = Property(Str, depends_on='file')
    inst_dir = Property(depends_on='file')
    inst = Property(depends_on='file')

    points_filter = Any(desc="Index to select a subset of the head shape "
                        "points")
    n_omitted = Property(Int, depends_on=['points_filter'])

    # head shape
    inst_points = Property(depends_on='inst', desc="Head shape points in the "
                           "inst file(n x 3 array)")
    points = Property(depends_on=['inst_points', 'points_filter'], desc="Head "
                      "shape points selected by the filter (n x 3 array)")

    # fiducials
    fid_dig = Property(depends_on='inst', desc="Fiducial points "
                       "(list of dict)")
    fid_points = Property(depends_on='fid_dig', desc="Fiducial points {ident: "
                          "point} dict}")
    lpa = Property(depends_on='fid_points', desc="LPA coordinates (1 x 3 "
                   "array)")
    nasion = Property(depends_on='fid_points', desc="Nasion coordinates (1 x "
                      "3 array)")
    rpa = Property(depends_on='fid_points', desc="RPA coordinates (1 x 3 "
                   "array)")

    view = View(
        VGroup(Item('file'),
               Item('inst_fname', show_label=False, style='readonly')))

    @cached_property
    def _get_n_omitted(self):
        if self.points_filter is None:
            return 0
        else:
            return np.sum(self.points_filter == False)  # noqa: E712

    @cached_property
    def _get_inst(self):
        if self.file:
            info = read_info(self.file)
            if info['dig'] is None:
                error(None, "The selected FIFF file does not contain "
                      "digitizer information. Please select a different "
                      "file.", "Error Reading FIFF File")
                self.reset_traits(['file'])
            else:
                return info

    @cached_property
    def _get_inst_dir(self):
        return os.path.dirname(self.file)

    @cached_property
    def _get_inst_fname(self):
        if self.file:
            return os.path.basename(self.file)
        else:
            return '-'

    @cached_property
    def _get_inst_points(self):
        if not self.inst:
            return np.zeros((1, 3))

        points = np.array([
            d['r'] for d in self.inst['dig']
            if d['kind'] == FIFF.FIFFV_POINT_EXTRA
        ])
        # Without any extra points np.array([]) has shape (0,); reshape so the
        # result is always (n, 3) like the documented contract.
        return points.reshape(-1, 3)

    @cached_property
    def _get_points(self):
        if self.points_filter is None:
            return self.inst_points
        else:
            return self.inst_points[self.points_filter]

    @cached_property
    def _get_fid_dig(self):  # noqa: D401
        """Fiducials for info['dig']."""
        if not self.inst:
            return []
        dig = self.inst['dig']
        dig = [d for d in dig if d['kind'] == FIFF.FIFFV_POINT_CARDINAL]
        return dig

    @cached_property
    def _get_fid_points(self):
        if not self.inst:
            return {}
        # Map each cardinal point's ident to its full dig dict.
        return {d['ident']: d for d in self.fid_dig}

    def _fid_coordinates(self, ident):
        """Return the (1, 3) coordinates of fiducial ``ident``, or zeros.

        Zeros are returned both when there are no fiducials at all and when
        this particular fiducial is missing, instead of raising KeyError.
        """
        dig_point = self.fid_points.get(ident) if self.fid_points else None
        if dig_point is None:
            return np.zeros((1, 3))
        return dig_point['r'][None, :]

    @cached_property
    def _get_nasion(self):
        return self._fid_coordinates(FIFF.FIFFV_POINT_NASION)

    @cached_property
    def _get_lpa(self):
        return self._fid_coordinates(FIFF.FIFFV_POINT_LPA)

    @cached_property
    def _get_rpa(self):
        return self._fid_coordinates(FIFF.FIFFV_POINT_RPA)

    def _file_changed(self):
        # A new file invalidates any previous point selection.
        self.reset_traits(('points_filter', ))
class AdapterManager(HasTraits):
    """ A manager for adapter factories. """

    #### 'AdapterManager' interface ###########################################

    # All registered type-scope factories by the type of object that they
    # adapt.
    #
    # The dictionary is keyed by the *name* of the class rather than the class
    # itself to allow for adapter factory proxies to register themselves
    # without having to load and create the factories themselves (i.e., to
    # allow us to lazily load adapter factories contributed by plugins). This
    # is a slight compromise as it is obviously geared towards use in Envisage,
    # but it doesn't affect the API other than allowing a class OR a string to
    # be passed to 'register_adapters'.
    #
    # { String adaptee_class_name : List(AdapterFactory) factories }
    type_factories = Property(Dict)

    # All registered instance-scope factories by the object that they adapt.
    #
    # { id(obj) : List(AdapterFactory) factories }
    instance_factories = Property(Dict)

    # The type system used by the manager (it determines 'is_a' relationships
    # and type MROs etc). By default we use standard Python semantics.
    type_system = Instance(AbstractTypeSystem, PythonTypeSystem())

    #### Private interface ####################################################

    # All registered type-scope factories by the type of object that they
    # adapt.
    _type_factories = Dict

    # All registered instance-scope factories by the object that they adapt.
    _instance_factories = Dict

    ###########################################################################
    # 'AdapterManager' interface.
    ###########################################################################

    #### Properties ###########################################################

    def _get_type_factories(self):
        """ Returns a (shallow) copy of all registered type-scope factories.
        """
        return self._type_factories.copy()

    def _get_instance_factories(self):
        """ Returns a (shallow) copy of all registered instance-scope
        factories.
        """
        return self._instance_factories.copy()

    #### Methods ##############################################################

    def adapt(self, adaptee, target_class, *args, **kw):
        """ Returns an adapter that adapts an object to the target class.

        'adaptee' is the object that we want to adapt.
        'target_class' is the class that the adaptee should be adapted to.

        Instance-scope factories are tried first, then type-scope factories
        for each class in the adaptee's MRO (as given by the type system).

        Returns None if no such adapter exists.

        """
        # If the object is already an instance of the target class then we
        # simply return it.
        if self.type_system.is_a(adaptee, target_class):
            adapter = adaptee

        # Otherwise, look at each class in the adaptee's MRO to see if there
        # is an adapter factory registered that can adapt the object to the
        # target class.
        else:
            # Look for instance-scope adapters first.
            adapter = self._adapt_instance(adaptee, target_class, *args, **kw)

            # If no instance-scope adapter was found then try type-scope
            # adapters.
            if adapter is None:
                for adaptee_class in self.type_system.get_mro(type(adaptee)):
                    adapter = self._adapt_type(
                        adaptee, adaptee_class, target_class, *args, **kw
                    )
                    if adapter is not None:
                        break

        return adapter

    def register_instance_adapters(self, factory, obj):
        """ Registers an instance-scope adapter factory.

        A factory can be in exactly one manager (as it uses the manager's type
        system).

        """
        factories = self._instance_factories.setdefault(id(obj), [])
        factories.append(factory)

        # A factory can be in exactly one manager.
        factory.adapter_manager = self

        return

    def unregister_instance_adapters(self, factory, obj):
        """ Unregisters an instance scope adapter factory.

        A factory can be in exactly one manager (as it uses the manager's type
        system).

        """
        key = id(obj)

        # Use 'get' rather than 'setdefault': unregistering from an object that
        # was never registered must not create an entry, and emptied entries
        # are pruned so that 'id'-keyed lists do not accumulate forever.
        factories = self._instance_factories.get(key)
        if factories is not None:
            if factory in factories:
                factories.remove(factory)
            if not factories:
                del self._instance_factories[key]

        # The factory is no longer deemed to be part of this manager.
        factory.adapter_manager = None

        return

    def register_type_adapters(self, factory, adaptee_class):
        """ Registers a type-scope adapter factory.

        'adaptee_class' can be either a class object or the name of a class.

        A factory can be in exactly one manager (as it uses the manager's type
        system).

        """
        if isinstance(adaptee_class, six.string_types):
            adaptee_class_name = adaptee_class

        else:
            adaptee_class_name = self._get_class_name(adaptee_class)

        factories = self._type_factories.setdefault(adaptee_class_name, [])
        factories.append(factory)

        # A factory can be in exactly one manager.
        factory.adapter_manager = self

        return

    def unregister_type_adapters(self, factory):
        """ Unregisters a type-scope adapter factory. """

        for factories in self._type_factories.values():
            if factory in factories:
                factories.remove(factory)

        # The factory is no longer deemed to be part of this manager.
        factory.adapter_manager = None

        return

    #### DEPRECATED ###########################################################

    def register_adapters(self, factory, adaptee_class):
        """ Registers an adapter factory.

        'adaptee_class' can be either a class object or the name of a class.

        A factory can be in exactly one manager (as it uses the manager's type
        system).

        """
        # stacklevel=2 attributes the warning to the caller, not to this
        # wrapper.
        warnings.warn('Use "register_type_adapters" instead.',
                      DeprecationWarning, stacklevel=2)

        self.register_type_adapters(factory, adaptee_class)

        return

    def unregister_adapters(self, factory):
        """ Unregisters an adapter factory. """

        warnings.warn('use "unregister_type_adapters" instead.',
                      DeprecationWarning, stacklevel=2)

        self.unregister_type_adapters(factory)

        return

    ###########################################################################
    # Private interface.
    ###########################################################################

    def _adapt_instance(self, adaptee, target_class, *args, **kw):
        """ Returns an adapter that adapts an object to the target class.

        Only factories registered for this specific object (by 'id') are
        tried, in registration order. Returns None if no such adapter exists.

        """
        for factory in self._instance_factories.get(id(adaptee), []):
            adapter = factory.adapt(adaptee, target_class, *args, **kw)
            if adapter is not None:
                break

        else:
            adapter = None

        return adapter

    def _adapt_type(self, adaptee, adaptee_class, target_class, *args, **kw):
        """ Returns an adapter that adapts an object to the target class.

        Only factories registered for 'adaptee_class' are tried, in
        registration order. Returns None if no such adapter exists.

        """
        adaptee_class_name = self._get_class_name(adaptee_class)

        for factory in self._type_factories.get(adaptee_class_name, []):
            adapter = factory.adapt(adaptee, target_class, *args, **kw)
            if adapter is not None:
                break

        else:
            adapter = None

        return adapter

    def _get_class_name(self, klass):
        """ Returns the full class name for a class ('module.ClassName'). """

        return "%s.%s" % (klass.__module__, klass.__name__)
class Vis2D(HasStrictTraits):
    '''Base for objects that can be shown by 2D visualization objects.

    Each state and operator object may own several ``Viz2D`` instances
    (collected in ``viz2d``). The classes that may be instantiated for this
    object are registered by name in ``viz2d_classes``; the ``actions`` group
    lets the user pick one of these names and add a new instance.
    '''

    #: Root of the simulation to extract the data
    sim = WeakRef

    #: Visual object time
    vot = Float(0.0, time_change=True)

    #: Visualization classes applicable to this object.
    viz2d_classes = Dict

    #: Keys of the viz2d classes
    viz2d_class_names = Property(List(Str), depends_on='viz2d_classes')

    #: Name of the class instantiated by the 'add_selected_viz2d' button
    selected_viz2d_class = Str

    #: Button adding an instance of the selected viz2d class
    add_selected_viz2d = Button(label='Add plot viz2d')

    #: Visualization objects attached to this object
    viz2d = List(Viz2D)

    actions = HGroup(
        UItem('add_selected_viz2d'),
        UItem('selected_viz2d_class', springy=True,
              editor=EnumEditor(name='object.viz2d_class_names')),
    )

    view = View(Include('actions'), resizable=True)

    def setup(self):
        pass

    @cached_property
    def _get_viz2d_class_names(self):
        # Iterating a dict yields its keys in insertion order.
        return list(self.viz2d_classes)

    def _selected_viz2d_class_default(self):
        # Pre-select the first registered class, if any.
        names = self.viz2d_class_names
        return names[0] if names else ''

    def _add_selected_viz2d_fired(self):
        self.add_viz2d(self.selected_viz2d_class, '<unnamed>')

    def add_viz2d(self, class_name, name, **kw):
        '''Instantiate the viz2d class registered as ``class_name``.

        An empty ``name`` falls back to ``class_name``. The new object is
        appended to ``viz2d`` and, when a ``ui`` is attached, to its viz
        sheet as well.
        '''
        factory = self.viz2d_classes[class_name]
        label = class_name if name == '' else name
        new_viz2d = factory(name=label, vis2d=self, **kw)
        self.viz2d.append(new_viz2d)
        ui = getattr(self, 'ui', None)
        if ui:
            ui.viz_sheet.viz2d_list.append(new_viz2d)

    def plt(self, name, label=None):
        return Viz2DPlot(plot_fn=name, label=label, vis2d=self)
class StatisticViewHandlerMixin(HasTraits):
    """Handler mixin exposing index information about the selected statistic.

    The host class must provide ``info`` (whose UI context holds, under the
    key ``'context'``, an object with a ``statistics`` mapping) and ``model``
    (with ``statistic`` and ``subset`` attributes).
    """

    numeric_indices = Property(depends_on="model.statistic, model.subset")
    indices = Property(depends_on="model.statistic, model.subset")
    levels = Property(depends_on="model.statistic")

    # MAGIC: gets the value for the property numeric_indices
    def _get_numeric_indices(self):
        """Names of the non-constant index levels that hold numeric values."""
        data = self._subset_data()
        if data is None:
            return []
        data.reset_index(inplace=True)
        return [x for x in data if util.is_numeric(data[x])]

    # MAGIC: gets the value for the property indices
    def _get_indices(self):
        """Names of the non-constant index levels of the (subset) statistic."""
        data = self._subset_data()
        if data is None:
            return []
        return list(data.index.names)

    # MAGIC: gets the value for the property 'levels'
    # returns a Dict(Str, pd.Series)
    def _get_levels(self):
        """Sorted unique values of each non-constant index level."""
        stat = self._selected_statistic()
        if stat is None:
            # An empty dict, consistent with the non-empty return type.
            return {}

        index = self._drop_constant_levels(stat.index)
        return {
            name: pd.Series(
                pd.Series(index.get_level_values(name)).sort_values().unique())
            for name in index.names
        }

    def _selected_statistic(self):
        """Return the statistic selected by the model, or None if unavailable.
        """
        context = self.info.ui.context['context']
        if not (context and context.statistics
                and self.model and self.model.statistic[0]):
            return None
        return context.statistics[self.model.statistic]

    def _subset_data(self):
        """Return an empty frame indexed like the selected statistic.

        The model's subset query is applied and constant index levels are
        dropped. Returns None if no statistic is selected or nothing matches.
        """
        stat = self._selected_statistic()
        if stat is None:
            return None

        data = pd.DataFrame(index=stat.index)
        if self.model.subset:
            data = data.query(self.model.subset)
        if len(data) == 0:
            return None

        data.index = self._drop_constant_levels(data.index)
        return data

    @staticmethod
    def _drop_constant_levels(index):
        """Drop index levels holding a single value, keeping at least one."""
        for name in list(index.names):
            # pandas raises ValueError when asked to drop the last remaining
            # level (e.g. when the subset selects a single row), so always
            # keep one.
            if (index.nlevels > 1
                    and len(index.get_level_values(name).unique()) == 1):
                index = index.droplevel(name)
        return index
class MassBalanceAnalyzer(HasStrictTraits):
    """ Tool to compare loaded mass from continuous data to the method/solution
    information.
    """
    #: Target information we are evaluating mass balance for
    target_experiment = Instance(Experiment)

    #: Time before which to truncate experiment data
    time_of_origin = Instance(UnitScalar)

    #: Continuous UV data to evaluate loaded mass from (passed separately to
    #: support doing analysis while building an exp)
    continuous_data = Instance(XYData)

    #: Loaded mass, as computed from method information
    mass_from_method = Instance(UnitScalar)

    #: Loaded mass, as computed from UV continuous data
    mass_from_uv = Instance(UnitScalar)

    #: Current load product concentration in target experiment
    current_concentration = Property(Instance(UnitScalar))

    #: Current load step volume in target experiment
    current_volume = Property(Instance(UnitScalar))

    #: Threshold for relative difference to define the data as imbalanced
    balance_threshold = Float

    #: Whether the target experiment's data is balanced (in agreement)
    balanced = Bool

    # View elements -----------------------------------------------------------

    #: Plot of the UV data
    plot = Instance(Plot)

    #: Data container for the plot
    plot_data = Instance(ArrayPlotData)

    # Traits methods ----------------------------------------------------------

    def __init__(self, **traits):
        if "target_experiment" not in traits:
            msg = "A target experiment is required to create a {}"
            msg = msg.format(self.__class__.__name__)
            # Not inside an exception handler: logger.exception would attach
            # a meaningless "NoneType: None" traceback.
            logger.error(msg)
            raise ValueError(msg)

        super(MassBalanceAnalyzer, self).__init__(**traits)

        use_exp_data = (self.continuous_data is None and
                        self.target_experiment.output is not None)
        # If the tool is created for a fully created experiment, extract the
        # continuous data from the experiment.
        if use_exp_data:
            data_dict = self.target_experiment.output.continuous_data
            self.continuous_data = data_dict[UV_DATA_KEY]

    def traits_view(self):
        view = KromView(
            Item("plot", editor=ComponentEditor(), show_label=False)
        )
        return view

    # Public interface --------------------------------------------------------

    def analyze(self):
        """ Compare the loaded mass from the method to the one from UV data.

        Sets ``mass_from_method``, ``mass_from_uv`` and ``balanced``. When no
        UV-based mass can be computed (no continuous data), the experiment is
        considered balanced.

        Returns
        -------
        None

        Raises
        ------
        ValueError
            If the target product has more than one component, since no
            single product extinction coefficient is available.
        """
        product = self.target_experiment.product
        if len(product.product_components) > 1:
            msg = "Unable to analyze the loaded mass for multi-component " \
                  "products, since we don't have a product extinction coeff."
            logger.error(msg)
            raise ValueError(msg)

        msg = "Analyzing mass balance for {}".format(
            self.target_experiment.name)
        logger.debug(msg)

        self.mass_from_method = self.loaded_mass_from_method()
        self.mass_from_uv = self.loaded_mass_from_uv()
        if self.mass_from_uv is None:
            self.balanced = True
            return

        diff = abs(self.mass_from_method - self.mass_from_uv)
        rel_diff = diff / self.mass_from_method
        msg = "Loaded mass computed from method and UV are different by " \
              "{:.3g}%".format(float(rel_diff)*100)
        logger.debug(msg)

        self.balanced = float(rel_diff) <= self.balance_threshold

    def loaded_mass_from_method(self):
        """ Returns loaded product mass in grams, as computed from method data.

        Raises
        ------
        NotImplementedError
            If the load step volume is neither in column volumes nor in
            volume units.
        """
        vol = self.current_volume
        if units_almost_equal(vol, column_volumes):
            vol = float(vol) * self.target_experiment.column.volume
        elif has_volume_units(vol):
            pass
        else:
            msg = "Unexpected unit for the load step volume: {}"
            msg = msg.format(vol.units.label)
            logger.error(msg)
            raise NotImplementedError(msg)

        mass = vol * self.current_concentration
        return convert_units(mass, tgt_unit="gram")

    def loaded_mass_from_uv(self):
        """ Returns loaded product mass in grams, as computed from UV data.

        The UV signal is integrated between ``time_of_origin`` and the last
        method step boundary time. Returns None if there is no continuous
        data.
        """
        if not self.continuous_data:
            return None

        data = self.continuous_data
        product = self.target_experiment.product
        ext_coeff = product.product_components[0].extinction_coefficient
        method_step_times = self.target_experiment.method_step_boundary_times
        t_stop = UnitScalar(method_step_times[-1],
                            units=method_step_times.units)
        mass = compute_mass_from_abs_data(data, ext_coeff=ext_coeff,
                                          experim=self.target_experiment,
                                          t_start=self.time_of_origin,
                                          t_stop=t_stop)
        return convert_units(mass, tgt_unit="gram")

    # Adjustment methods ------------------------------------------------------

    def compute_loaded_vol(self, tgt_units=column_volumes):
        """ Compute the load step volume that would match the UV data at
        constant load solution concentration.

        Requires ``mass_from_uv`` to have been computed (e.g. by
        :meth:`analyze`).

        Parameters
        ----------
        tgt_units : Unit
            ``column_volumes`` (default) or ``g_per_liter_resin``.

        Returns
        -------
        UnitScalar
            Load step volume (in CV), or the equivalent load in g/Liter of
            resin, that would be needed to match the UV data.

        Raises
        ------
        ValueError
            If ``tgt_units`` is not one of the supported units.
        """
        if tgt_units not in [column_volumes, g_per_liter_resin]:
            msg = "Supported target units are CV and g/Liter of resin but " \
                  "got {}.".format(tgt_units.label)
            logger.error(msg)
            raise ValueError(msg)

        target_mass = self.mass_from_uv
        col_volume = self.target_experiment.column.volume
        concentration = self.current_concentration
        vol = target_mass / concentration / col_volume
        # Test equality on the labels since CV and g_per_liter_resin are equal
        # from a derivation point of view (dimensionless)
        if tgt_units.label == g_per_liter_resin.label:
            vol = float(vol * concentration)
            vol = UnitScalar(vol, units=g_per_liter_resin)
        else:
            vol = UnitScalar(vol, units=column_volumes)
        return vol

    def compute_concentration(self):
        """ Compute the load solution concentration that would match the UV
        data at constant load step volume.

        Requires ``mass_from_uv`` to have been computed (e.g. by
        :meth:`analyze`).

        Returns
        -------
        UnitScalar
            Load solution concentration, in g/L, that would be needed to match
            the UV data.
        """
        target_mass = self.mass_from_uv
        vol = self.current_volume
        if units_almost_equal(vol, column_volumes):
            vol = float(vol) * self.target_experiment.column.volume

        concentration = target_mass / vol
        return convert_units(concentration, tgt_unit="g/L")

    # Traits property getters/setters -----------------------------------------

    def _get_current_volume(self):
        load_step = self.target_experiment.method.load
        return load_step.volume

    def _get_current_concentration(self):
        load_sol = self.target_experiment.method.load.solutions[0]
        comp_concs = load_sol.product_component_concentrations
        return unitarray_to_unitted_list(comp_concs)[0]

    # Traits listeners --------------------------------------------------------

    def _continuous_data_changed(self):
        # Nothing to refresh until the plot has been built, and nothing to
        # plot once the data has been cleared (avoids an AttributeError on
        # None).
        if self.plot_data is None or self.continuous_data is None:
            return

        self.plot_data.update_data(x=self.continuous_data.x_data,
                                   y=self.continuous_data.y_data)

    # Traits initialization methods -------------------------------------------

    def _balance_threshold_default(self):
        from kromatography.utils.app_utils import get_preferences
        prefs = get_preferences()
        return prefs.file_preferences.exp_importer_mass_threshold

    def _plot_default(self):
        self.plot_data = ArrayPlotData(x=self.continuous_data.x_data,
                                       y=self.continuous_data.y_data)
        plot = Plot(self.plot_data)
        plot.plot(("x", "y"))
        x_units = self.continuous_data.x_metadata["units"]
        y_units = self.continuous_data.y_metadata["units"]
        plot.index_axis.title = "Time ({})".format(x_units)
        plot.value_axis.title = "UV Absorption ({})".format(y_units)
        # Add zoom and pan tools to the plot
        zoom = BetterSelectingZoom(component=plot, tool_mode="box",
                                   always_on=False)
        plot.overlays.append(zoom)
        plot.tools.append(PanTool(component=plot))
        return plot
class VUMeter(Component):
    """ An analog-style VU meter: a needle over a curved, labeled axis.

    The curved axis runs from ``angle`` to ``180 - angle`` degrees, measured
    counterclockwise from the horizontal line through the needle's hinge
    point. Labeled dB ticks are drawn just outside the axis and percent ticks
    just inside it. The part of the axis between 100% and ``max_percent`` is
    drawn in red. The needle shows ``percent`` (equivalently ``db``), clamped
    to ``max_percent`` when drawn.
    """

    # Value expressed in dB. Reading converts `percent` with percent_to_db;
    # writing sets `percent` via db_to_percent.
    db = Property(Float)

    # Value expressed as a percent.
    percent = Range(low=0.0)

    # The maximum value to be display in the VU Meter, expressed as a percent.
    max_percent = Float(150.0)

    # Angle (in degrees) from a horizontal line through the hinge of the
    # needle to the edge of the meter axis.
    angle = Float(45.0)

    # Values of the percentage-based ticks; these are drawn and labeled along
    # the bottom of the curve axis.
    percent_ticks = List(list(range(0, 101, 20)))

    # Text to write in the middle of the VU Meter.
    text = Str("VU")

    # Font used to draw `text`.
    text_font = KivaFont("modern 48")

    # Font for the db tick labels.
    db_tick_font = KivaFont("modern 16")

    # Font for the percent tick labels.
    percent_tick_font = KivaFont("modern 12")

    # beta is the fraction of the of needle that is "hidden".
    # beta == 0 puts the hinge point of the needle on the bottom
    # edge of the window.  Values that result in a decent looking
    # meter are 0 < beta < .65.
    # XXX needs a better name!
    _beta = Float(0.3)

    # _outer_radial_margin is the radial extent beyond the circular axis
    # to include in calculations of the space required for the meter.
    # This allows room for the ticks and labels.
    _outer_radial_margin = Float(60.0)

    # The angle (in radians) of the span of the curve axis.
    _phi = Property(Float, depends_on=['angle'])

    # This is the radius of the circular axis (in screen coordinates).
    _axis_radius = Property(Float, depends_on=['_phi', 'width', 'height'])

    #---------------------------------------------------------------------
    # Trait Property methods
    #---------------------------------------------------------------------

    def _get_db(self):
        db = percent_to_db(self.percent)
        return db

    def _set_db(self, value):
        self.percent = db_to_percent(value)

    def _get__phi(self):
        # Span between the two axis ends (angle .. 180 - angle), in radians.
        phi = math.pi * (180.0 - 2 * self.angle) / 180.0
        return phi

    def _get__axis_radius(self):
        """ Largest axis radius for which the meter fits the component. """
        M = self._outer_radial_margin
        beta = self._beta
        w = self.width
        h = self.height
        phi = self._phi

        # R1: the outer edge (radius R + M) spans a horizontal chord of
        # 2 * (R + M) * sin(phi / 2); this must not exceed the width.
        R1 = w / (2 * math.sin(phi / 2)) - M
        # R2: the hinge sits beta * R * cos(phi / 2) below the bottom edge
        # (see `oy` in _draw_mainlayer), and the outer edge reaches R + M
        # above the hinge; this must not exceed the height.
        R2 = (h - M) / (1 - beta * math.cos(phi / 2))
        R = min(R1, R2)
        return R

    #---------------------------------------------------------------------
    # Trait change handlers
    #---------------------------------------------------------------------

    def _anytrait_changed(self):
        # Any trait change may affect the appearance, so always redraw.
        self.request_redraw()

    #---------------------------------------------------------------------
    # Component API
    #---------------------------------------------------------------------

    def _draw_mainlayer(self, gc, view_bounds=None, mode="default"):
        """ Draw ticks, labels, text, axis arcs and the needle. """

        beta = self._beta
        phi = self._phi

        # NOTE(review): this value is overwritten (w = 10, w = 4) below
        # before it is ever read.
        w = self.width

        M = self._outer_radial_margin
        R = self._axis_radius

        # (ox, oy) is the position of the "hinge point" of the needle
        # (i.e. the center of rotation).  For beta > ~0, oy is negative,
        # so this point is below the visible region.
        # NOTE(review): unlike ox, oy does not include self.y -- confirm this
        # is intended for components not placed at y == 0.
        ox = self.x + self.width // 2
        oy = -beta * R * math.cos(phi / 2) + 1

        left_theta = math.radians(180 - self.angle)
        right_theta = math.radians(self.angle)

        # The angle of the 100% position.
        nominal_theta = self._percent_to_theta(100.0)

        # The color of the axis for percent > 100.
        red = (0.8, 0, 0)

        with gc:
            gc.set_antialias(True)

            # Draw everything relative to the center of the circles.
            gc.translate_ctm(ox, oy)

            # Draw the primary ticks and tick labels on the curved axis.
            # Ticks run outward from the axis; labels sit beyond them.
            gc.set_fill_color((0, 0, 0))
            gc.set_font(self.db_tick_font)
            for db in [-20, -10, -7, -5, -3, -2, -1, 0, 1, 2, 3]:
                db_percent = db_to_percent(db)
                theta = self._percent_to_theta(db_percent)
                x1 = R * math.cos(theta)
                y1 = R * math.sin(theta)
                x2 = (R + 0.3 * M) * math.cos(theta)
                y2 = (R + 0.3 * M) * math.sin(theta)
                gc.set_line_width(2.5)
                gc.move_to(x1, y1)
                gc.line_to(x2, y2)
                gc.stroke_path()

                text = str(db)
                if db > 0:
                    text = '+' + text
                self._draw_rotated_label(gc, text, theta, R + 0.4 * M)

            # Draw the secondary ticks on the curve axis (shorter, thinner,
            # unlabeled).
            for db in [-15, -9, -8, -6, -4, -0.5, 0.5]:
                ##db_percent = 100 * math.pow(10.0, db / 20.0)
                db_percent = db_to_percent(db)
                theta = self._percent_to_theta(db_percent)
                x1 = R * math.cos(theta)
                y1 = R * math.sin(theta)
                x2 = (R + 0.2 * M) * math.cos(theta)
                y2 = (R + 0.2 * M) * math.sin(theta)
                gc.set_line_width(1.0)
                gc.move_to(x1, y1)
                gc.line_to(x2, y2)
                gc.stroke_path()

            # Draw the percent ticks and label on the bottom of the
            # curved axis (inside the axis, in gray). Only the last label
            # gets a "%" suffix.
            gc.set_font(self.percent_tick_font)
            gc.set_fill_color((0.5, 0.5, 0.5))
            gc.set_stroke_color((0.5, 0.5, 0.5))
            percents = self.percent_ticks
            for tick_percent in percents:
                theta = self._percent_to_theta(tick_percent)
                x1 = (R - 0.15 * M) * math.cos(theta)
                y1 = (R - 0.15 * M) * math.sin(theta)
                x2 = R * math.cos(theta)
                y2 = R * math.sin(theta)
                gc.set_line_width(2.0)
                gc.move_to(x1, y1)
                gc.line_to(x2, y2)
                gc.stroke_path()

                text = str(tick_percent)
                if tick_percent == percents[-1]:
                    text = text + "%"
                self._draw_rotated_label(gc, text, theta, R - 0.3 * M)

            # Draw the caption text, horizontally centered above the hinge,
            # in translucent black.
            if self.text:
                gc.set_font(self.text_font)
                tx, ty, tw, th = gc.get_text_extent(self.text)
                gc.set_fill_color((0, 0, 0, 0.25))
                gc.set_text_matrix(affine.affine_from_rotation(0))
                gc.set_text_position(-0.5 * tw,
                                     (0.75 * beta + 0.25) * R)
                gc.show_text(self.text)

            # Draw the red curved axis (from max_percent back to 100%).
            # The radius is offset by half the line width so that the inner
            # edge of the thick stroke lies on the axis radius R.
            gc.set_stroke_color(red)
            w = 10
            gc.set_line_width(w)
            gc.arc(0, 0, R + 0.5 * w - 1, right_theta, nominal_theta)
            gc.stroke_path()

            # Draw the black curved axis (from 100% to 0%).
            w = 4
            gc.set_line_width(w)
            gc.set_stroke_color((0, 0, 0))
            gc.arc(0, 0, R + 0.5 * w - 1, nominal_theta, left_theta)
            gc.stroke_path()

            # Draw the filled arc at the bottom: outline first, then the
            # translucent fill over the same arc.
            gc.set_line_width(2)
            gc.set_stroke_color((0, 0, 0))
            gc.arc(0, 0, beta * R, math.radians(self.angle),
                   math.radians(180 - self.angle))
            gc.stroke_path()
            gc.set_fill_color((0, 0, 0, 0.25))
            gc.arc(0, 0, beta * R, math.radians(self.angle),
                   math.radians(180 - self.angle))
            gc.fill_path()

            # Draw the needle.
            percent = self.percent
            # If percent exceeds max_percent, the needle is drawn at max_percent.
            if percent > self.max_percent:
                percent = self.max_percent
            needle_theta = self._percent_to_theta(percent)
            # The needle is drawn pointing straight up (theta = pi / 2), so
            # rotate by the difference to point it at needle_theta.
            gc.rotate_ctm(needle_theta - 0.5 * math.pi)
            self._draw_vertical_needle(gc)

    #---------------------------------------------------------------------
    # Private methods
    #---------------------------------------------------------------------

    def _draw_vertical_needle(self, gc):
        """ Draw the needle of the meter, pointing straight up.

        Coordinates are relative to the hinge point; the caller rotates the
        gc to orient the needle.
        """
        beta = self._beta
        R = self._axis_radius
        end_y = beta * R
        blob_y = R - 0.6 * self._outer_radial_margin
        tip_y = R + 0.2 * self._outer_radial_margin
        lw = 5

        with gc:
            gc.set_alpha(1)
            gc.set_fill_color((0, 0, 0))

            # Draw the needle from the bottom to the blob.
            gc.set_line_width(lw)
            gc.move_to(0, end_y)
            gc.line_to(0, blob_y)
            gc.stroke_path()

            # Draw the thin part of the needle from the blob to the tip.
            gc.move_to(lw, blob_y)
            control_y = blob_y + 0.25 * (tip_y - blob_y)
            gc.quad_curve_to(0.2 * lw, control_y, 0, tip_y)
            gc.quad_curve_to(-0.2 * lw, control_y, -lw, blob_y)
            gc.line_to(lw, blob_y)
            gc.fill_path()

            # Draw the blob on the needle.
            gc.arc(0, blob_y, 6.0, 0, 2 * math.pi)
            gc.fill_path()

    def _draw_rotated_label(self, gc, text, theta, radius):
        """ Draw `text` centered on the radial line at angle `theta`.

        The baseline is tangent to the circle of the given `radius`. The text
        origin is offset by half the text width along the tangent, so that
        the rotated text is centered on `theta`.
        """

        tx, ty, tw, th = gc.get_text_extent(text)

        rr = math.sqrt(radius**2 + (0.5 * tw)**2)
        dtheta = math.atan2(0.5 * tw, radius)
        text_theta = theta + dtheta
        x = rr * math.cos(text_theta)
        y = rr * math.sin(text_theta)

        rot_theta = theta - 0.5 * math.pi
        with gc:
            gc.set_text_matrix(affine.affine_from_rotation(rot_theta))
            gc.set_text_position(x, y)
            gc.show_text(text)

    def _percent_to_theta(self, percent):
        """ Convert percent to the angle theta, in radians.

        theta is the angle of the needle measured counterclockwise from
        the horizontal (i.e. the traditional angle of polar coordinates).
        0% maps to 180 - angle degrees (left end of the axis) and
        max_percent maps to angle degrees (right end), linearly in between.
        """
        angle = (self.angle + (180.0 - 2 * self.angle) *
                 (self.max_percent - percent) / self.max_percent)
        theta = math.radians(angle)
        return theta

    def _db_to_theta(self, db):
        """ Convert db to the angle theta, in radians. """
        percent = db_to_percent(db)
        theta = self._percent_to_theta(percent)
        return theta
class ImageActor(Module):
    """Module rendering image data through a ``tvtk.ImageActor``.

    When the source scalars are not a ``tvtk.UnsignedCharArray`` (or when the
    user asks for it), the data is first passed through
    ``image_map_to_color``, which maps scalars to colors using the scalar LUT.
    """

    # An image actor.
    actor = Instance(tvtk.ImageActor, allow_none=False, record=True)

    input_info = PipelineInfo(datasets=['image_data'],
                              attribute_types=['any'],
                              attributes=['any'])

    # An ImageMapToColors TVTK filter to adapt datasets without color
    # information
    image_map_to_color = Instance(tvtk.ImageMapToColors, (),
                                  allow_none=False, record=True)

    # Whether to route the data through `image_map_to_color`.
    map_scalars_to_color = Bool

    # True when the source scalars cannot be displayed directly.
    _force_map_scalars_to_color = Property(depends_on='module_manager.source')

    ########################################
    # The view of this module.

    view = View(
        Group(
            Item(name='actor', style='custom', resizable=True),
            show_labels=False,
            label='Actor',
        ),
        Group(
            Group(
                Item('map_scalars_to_color',
                     enabled_when='not _force_map_scalars_to_color'),
            ),
            Item('image_map_to_color', style='custom',
                 enabled_when='map_scalars_to_color', show_label=False),
            label='Map Scalars',
        ),
        width=500,
        height=600,
        resizable=True,
    )

    ######################################################################
    # `Module` interface
    ######################################################################

    def setup_pipeline(self):
        self.actor = tvtk.ImageActor()

    @on_trait_change('map_scalars_to_color,'
                     'image_map_to_color.[output_format,pass_alpha_to_output],'
                     'module_manager.scalar_lut_manager.lut_mode,'
                     'module_manager.vector_lut_manager.lut_mode')
    def update_pipeline(self):
        """Override this method so that it *updates* the tvtk pipeline
        when data upstream is known to have changed.

        Connects the actor either directly to the source's first output or
        to the output of the color-mapping filter.
        """
        manager = self.module_manager
        if manager is None:
            return

        source = manager.source
        if self._force_map_scalars_to_color:
            # Flip the flag silently so this handler is not re-triggered.
            self.set(map_scalars_to_color=True, trait_change_notify=False)

        if not self.map_scalars_to_color:
            actor_input = source.outputs[0]
        else:
            color_map = self.image_map_to_color
            self.configure_connection(color_map, source)
            color_map.lookup_table = manager.scalar_lut_manager.lut
            color_map.update()
            actor_input = color_map.output

        self.configure_input_data(self.actor, actor_input)
        self.pipeline_changed = True

    def update_data(self):
        """Override this method so that it flushes the vtk pipeline if
        that is necessary.
        """
        # Just set data_changed, the component should do the rest.
        self.data_changed = True

    ######################################################################
    # Non-public methods.
    ######################################################################

    def _actor_changed(self, old, new):
        # Swap the actor in `actors` and move the re-render listener to it.
        actor_list = self.actors
        if old is not None:
            actor_list.remove(old)
            old.on_trait_change(self.render, remove=True)
        actor_list.append(new)
        new.on_trait_change(self.render)

    def _get__force_map_scalars_to_color(self):
        manager = self.module_manager
        if manager is None:
            return False
        scalars = manager.source.outputs[0].point_data.scalars
        return not isinstance(scalars, tvtk.UnsignedCharArray)
class MATSXDMicroplaneDamageFatigueJir(MATSEvalMicroplaneFatigue):
    '''
    Microplane Damage Fatigue Model.

    The microplane state (normal / tangential damage and inelastic strains,
    delivered by ``get_normal_Law`` and ``get_tangential_Law``) is
    homogenized into a fourth order damage tensor that degrades the elastic
    stiffness ``D4_e``.
    '''
    #-------------------------------------------------------------------------
    # Setup for computation within a supplied spatial context
    #-------------------------------------------------------------------------

    D4_e = Property

    def _get_D4_e(self):
        # Return the elasticity tensor
        return self.elasticity_tensors

    #-------------------------------------------------------------------------
    # MICROPLANE-Kinematic constraints
    #-------------------------------------------------------------------------

    # get the dyadic product of the microplane normals
    _MPNN = Property(depends_on='n_mp')

    @cached_property
    def _get__MPNN(self):
        # dyadic product of the microplane normals: N_nij = n_i n_j
        MPNN_nij = einsum('ni,nj->nij', self._MPN, self._MPN)
        return MPNN_nij

    # get the third order tangential tensor (operator) for each microplane
    _MPTT = Property(depends_on='n_mp')

    @cached_property
    def _get__MPTT(self):
        # Third order tangential tensor for each microplane:
        #   T_nijr = 0.5 * (n_i delta_jr + n_j delta_ir) - n_i n_j n_r
        # The second term must be laid out as (n, i, j, r); writing it as
        # 'njir' would silently duplicate the first term.
        delta = identity(3)
        MPTT_nijr = 0.5 * (
            einsum('ni,jr -> nijr', self._MPN, delta) +
            einsum('nj,ir -> nijr', self._MPN, delta) -
            2 * einsum('ni,nj,nr -> nijr', self._MPN, self._MPN, self._MPN))
        return MPTT_nijr

    def _get_e_vct_arr(self, eps_eng):
        # Projection of apparent strain onto the individual microplanes
        e_ni = einsum('nj,ji->ni', self._MPN, eps_eng)
        return e_ni

    def _get_e_N_arr(self, e_vct_arr):
        # get the normal strain array for each microplane
        eN_n = einsum('ni,ni->n', e_vct_arr, self._MPN)
        return eN_n

    def _get_e_T_vct_arr(self, e_vct_arr):
        # get the tangential strain vector array for each microplane
        eN_n = self._get_e_N_arr(e_vct_arr)
        eN_vct_ni = einsum('n,ni->ni', eN_n, self._MPN)
        return e_vct_arr - eN_vct_ni

    #-------------------------------------------------
    # Alternative methods for the kinematic constraint
    #-------------------------------------------------

    def _get_e_N_arr_2(self, eps_eng):
        # normal strain per microplane: e_N = N_nij : eps_ij
        return einsum('nij,ij->n', self._MPNN, eps_eng)

    def _get_e_T_vct_arr_2(self, eps_eng):
        # tangential strain vector per microplane: e_T_r = T_nijr : eps_ij
        # (use the cached property instead of re-evaluating the getter)
        return einsum('nijr,ij->nr', self._MPTT, eps_eng)

    def _get_e_vct_arr_2(self, eps_eng):
        # full microplane strain vector = normal part * n + tangential part
        e_N_n = self._get_e_N_arr_2(eps_eng)
        e_T_nr = self._get_e_T_vct_arr_2(eps_eng)
        return einsum('n,ni->ni', e_N_n, self._MPN) + e_T_nr

    #--------------------------------------------------------
    # return the state variables (Damage , inelastic strains)
    #--------------------------------------------------------
    def _get_state_variables(self, sctx, eps_app_eng, sigma_kk):
        # Columns 0:5 are filled by the normal law, columns 5:14 by the
        # tangential law (see the accessors below for individual columns).
        e_N_arr = self._get_e_N_arr_2(eps_app_eng)
        e_T_vct_arr = self._get_e_T_vct_arr_2(eps_app_eng)

        sctx_arr = zeros_like(sctx)

        sctx_N = self.get_normal_Law(e_N_arr, sctx)
        sctx_arr[:, 0:5] = sctx_N

        sctx_tangential = self.get_tangential_Law(
            e_T_vct_arr, sctx, sigma_kk)
        sctx_arr[:, 5:14] = sctx_tangential

        return sctx_arr

    #-----------------------------------------------------------------
    # Returns a list of the plastic normal strain for all microplanes.
    #-----------------------------------------------------------------
    def _get_eps_N_p_arr(self, sctx, eps_app_eng, sigma_kk):
        eps_N_p = self._get_state_variables(
            sctx, eps_app_eng, sigma_kk)[:, 4]
        return eps_N_p

    #----------------------------------------------------------------
    # Returns a list of the sliding strain vector for all microplanes.
    #----------------------------------------------------------------
    def _get_eps_T_pi_arr(self, sctx, eps_app_eng, sigma_kk):
        eps_T_pi_vct_arr = self._get_state_variables(
            sctx, eps_app_eng, sigma_kk)[:, 10:13]
        return eps_T_pi_vct_arr

    #---------------------------------------------------------------------
    # Extra homogenization of damage tensor in case of two damage parameters
    #---------------------------------------------------------------------

    def _get_beta_N_arr(self, sctx, eps_app_eng, sigma_kk):
        # Returns an array of the normal integrity factors for all microplanes.
        beta_N_arr = sqrt(
            1 - self._get_state_variables(sctx, eps_app_eng, sigma_kk)[:, 0])
        return beta_N_arr

    def _get_beta_T_arr(self, sctx, eps_app_eng, sigma_kk):
        # Returns an array of the tangential integrity factors for all
        # microplanes.
        beta_T_arr = sqrt(
            1 - self._get_state_variables(sctx, eps_app_eng, sigma_kk)[:, 5])
        return beta_T_arr

    def _get_beta_tns(self, sctx, eps_app_eng, sigma_kk):
        # Returns the 4th order damage tensor 'beta4' using
        # (cf. [Baz99], Eq.(63))
        delta = identity(3)
        beta_N_n = self._get_beta_N_arr(sctx, eps_app_eng, sigma_kk)
        beta_T_n = self._get_beta_T_arr(sctx, eps_app_eng, sigma_kk)

        beta_ijkl = einsum('n, n, ni, nj, nk, nl -> ijkl', self._MPW, beta_N_n,
                           self._MPN, self._MPN, self._MPN, self._MPN) + \
            0.25 * (einsum('n, n, ni, nk, jl -> ijkl', self._MPW, beta_T_n,
                           self._MPN, self._MPN, delta) +
                    einsum('n, n, ni, nl, jk -> ijkl', self._MPW, beta_T_n,
                           self._MPN, self._MPN, delta) +
                    einsum('n, n, nj, nk, il -> ijkl', self._MPW, beta_T_n,
                           self._MPN, self._MPN, delta) +
                    einsum('n, n, nj, nl, ik -> ijkl', self._MPW, beta_T_n,
                           self._MPN, self._MPN, delta) -
                    4 * einsum('n, n, ni, nj, nk, nl -> ijkl', self._MPW, beta_T_n,
                               self._MPN, self._MPN, self._MPN, self._MPN))
        return beta_ijkl

    #-------------------------------------------------------------
    # Returns an array of the integrity factors for all microplanes.
    #-------------------------------------------------------------
    def _get_phi_arr(self, sctx, eps_app_eng, sigma_kk):
        # Evaluate the microplane laws once and read both damage columns.
        state_arr = self._get_state_variables(sctx, eps_app_eng, sigma_kk)
        w_N = state_arr[:, 0]
        w_T = state_arr[:, 5]

        # The sign of the volumetric strain (trace of the strain tensor =
        # sum of its eigenvalues) selects the normal or tangential damage.
        if np.trace(eps_app_eng) > 0.0:
            w = w_N
        else:
            w = w_T

        phi_arr = sqrt(1.0 - w)
        return phi_arr

    #----------------------------------------------
    # Returns the 2nd order damage tensor 'phi_mtx'
    #----------------------------------------------
    def _get_phi_mtx(self, sctx, eps_app_eng, sigma_kk):
        # scalar integrity factor for each microplane
        phi_arr = self._get_phi_arr(sctx, eps_app_eng, sigma_kk)
        # integration terms for each microplane
        phi_ij = einsum('n,n,nij->ij', phi_arr, self._MPW, self._MPNN)
        return phi_ij

    #----------------------------------------------------------------------
    # Returns the 4th order damage tensor 'beta4' using sum-type symmetrization
    # (cf. [Jir99], Eq.(21))
    #----------------------------------------------------------------------
    def _get_beta_tns_sum_type(self, sctx, eps_app_eng, sigma_kk):
        delta = identity(3)
        phi_mtx = self._get_phi_mtx(sctx, eps_app_eng, sigma_kk)

        # use numpy functionality (einsum) to evaluate [Jir99], Eq.(21)
        beta_ijkl = 0.25 * (einsum('ik,jl->ijkl', phi_mtx, delta) +
                            einsum('il,jk->ijkl', phi_mtx, delta) +
                            einsum('jk,il->ijkl', phi_mtx, delta) +
                            einsum('jl,ik->ijkl', phi_mtx, delta))
        return beta_ijkl

    #----------------------------------------------------------------------
    # Returns the 4th order damage tensor 'beta4' using product-type
    # symmetrization (cf. [Baz97], Eq.(87))
    #----------------------------------------------------------------------
    def _get_beta_tns_product_type(self, sctx, eps_app_eng, sigma_kk):
        phi_mtx = self._get_phi_mtx(sctx, eps_app_eng, sigma_kk)

        # spectral decomposition of the symmetric damage tensor:
        #   phi = V diag(lam) V^T   (eigenvectors are the columns of V)
        phi_eig_value, phi_eig_mtx = eigh(phi_mtx)

        # w_mtx = tensorial square root of the second order damage tensor,
        # transformed back to x-y-z coordinates: w = V diag(sqrt(lam)) V^T
        # (the previous form 'ik,kl,lj' multiplied by V instead of V^T)
        w_mtx = einsum('ik,k,jk -> ij', phi_eig_mtx,
                       sqrt(phi_eig_value.real), phi_eig_mtx)

        beta_ijkl = 0.5 * (einsum('ik,jl -> ijkl', w_mtx, w_mtx) +
                           einsum('il,jk -> ijkl', w_mtx, w_mtx))
        return beta_ijkl

    #-----------------------------------------------------------
    # Integration of the (inelastic) strains for each microplane
    #-----------------------------------------------------------
    def _get_eps_p_mtx(self, sctx, eps_app_eng, sigma_kk):
        # plastic normal strains
        eps_N_P_n = self._get_eps_N_p_arr(sctx, eps_app_eng, sigma_kk)
        # sliding tangential strains
        eps_T_pi_ni = self._get_eps_T_pi_arr(sctx, eps_app_eng, sigma_kk)
        delta = identity(3)

        # 2-nd order plastic (inelastic) tensor
        eps_p_ij = einsum('n,n,ni,nj -> ij', self._MPW, eps_N_P_n,
                          self._MPN, self._MPN) + \
            0.5 * (einsum('n,nr,ni,rj->ij', self._MPW, eps_T_pi_ni,
                          self._MPN, delta) +
                   einsum('n,nr,nj,ri->ij', self._MPW, eps_T_pi_ni,
                          self._MPN, delta))
        return eps_p_ij

    #-------------------------------------------------------------------------
    # Evaluation - get the corrector and predictor
    #-------------------------------------------------------------------------

    def get_corr_pred(self, sctx, eps_app_eng, sigma_kk):
        '''
        Corrector predictor computation.

        Returns the stress tensor and the damaged secant stiffness tensor.
        '''
        #------------------------------------------------------------------
        # Damage tensor (4th order) using sum-type symmetrization
        # (alternatives: _get_beta_tns, _get_beta_tns_product_type)
        #------------------------------------------------------------------
        beta_ijkl = self._get_beta_tns_sum_type(sctx, eps_app_eng, sigma_kk)

        #------------------------------------------------------------------
        # Damaged stiffness tensor calculated based on the damage tensor beta4
        #------------------------------------------------------------------
        D4_mdm_ijmn = einsum('ijkl,klsr,mnsr->ijmn',
                             beta_ijkl, self.D4_e, beta_ijkl)

        #------------------------------------------------------------------
        # Return stresses (corrector) and damaged secant stiffness (predictor)
        #------------------------------------------------------------------
        # plastic strain tensor
        eps_p_ij = self._get_eps_p_mtx(sctx, eps_app_eng, sigma_kk)
        # elastic strain tensor
        eps_e_mtx = eps_app_eng - eps_p_ij

        # calculation of the stress tensor
        sig_eng = einsum('ijmn,mn -> ij', D4_mdm_ijmn, eps_e_mtx)

        return sig_eng, D4_mdm_ijmn
class MATS3DMicroplaneDamageJir(MATSXDMicroplaneDamageFatigueJir):
    '''
    3D variant of MATSXDMicroplaneDamageFatigueJir using a fixed set of
    28 microplanes.
    '''

    # implements(IMATSEval)

    #-----------------------------------------------
    # number of microplanes - currently fixed for 3D
    #-----------------------------------------------

    n_mp = Constant(28)

    #-----------------------------------------------
    # get the normal vectors of the microplanes
    #-----------------------------------------------

    _MPN = Property(depends_on='n_mp')

    @cached_property
    def _get__MPN(self):
        # microplane normals (unit vectors, one row per microplane):
        return array([[.577350259, .577350259, .577350259],
                      [.577350259, .577350259, -.577350259],
                      [.577350259, -.577350259, .577350259],
                      [.577350259, -.577350259, -.577350259],
                      [.935113132, .250562787, .250562787],
                      [.935113132, .250562787, -.250562787],
                      [.935113132, -.250562787, .250562787],
                      [.935113132, -.250562787, -.250562787],
                      [.250562787, .935113132, .250562787],
                      [.250562787, .935113132, -.250562787],
                      [.250562787, -.935113132, .250562787],
                      [.250562787, -.935113132, -.250562787],
                      [.250562787, .250562787, .935113132],
                      [.250562787, .250562787, -.935113132],
                      [.250562787, -.250562787, .935113132],
                      [.250562787, -.250562787, -.935113132],
                      [.186156720, .694746614, .694746614],
                      [.186156720, .694746614, -.694746614],
                      [.186156720, -.694746614, .694746614],
                      [.186156720, -.694746614, -.694746614],
                      [.694746614, .186156720, .694746614],
                      [.694746614, .186156720, -.694746614],
                      [.694746614, -.186156720, .694746614],
                      [.694746614, -.186156720, -.694746614],
                      [.694746614, .694746614, .186156720],
                      [.694746614, .694746614, -.186156720],
                      [.694746614, -.694746614, .186156720],
                      [.694746614, -.694746614, -.186156720]])

    #-------------------------------------
    # get the weights of the microplanes
    #-------------------------------------

    _MPW = Property(depends_on='n_mp')

    @cached_property
    def _get__MPW(self):
        # The tabulated weights are given for a Gaussian integration over the
        # unit hemisphere and sum to 0.5 (cf. [BazLuz04]); they must be
        # multiplied by 6 (cf. [Baz05]).
        # Three groups, matching the row order of _MPN: 4, 12 and 12 planes.
        return array([.0160714276] * 4 +
                     [.0204744730] * 12 +
                     [.0158350505] * 12) * 6.0

    #-------------------------------------------------------------------------
    # Cached elasticity tensors
    #-------------------------------------------------------------------------

    elasticity_tensors = Property(
        depends_on='E, nu, dimensionality, stress_state')

    @cached_property
    def _get_elasticity_tensors(self):
        '''
        Initialize the fourth order isotropic elasticity tensor for the 3D
        case from Young's modulus E and Poisson's ratio nu:

            D_ijkl = la * d_ij d_kl + mu * (d_ik d_jl + d_il d_jk)

        (Only the 3D tensor is built here; no plane strain / plane stress
        variants are handled.)
        '''
        # first Lame parameter
        la = self.E * self.nu / ((1 + self.nu) * (1 - 2 * self.nu))
        # second Lame parameter (shear modulus)
        mu = self.E / (2 + 2 * self.nu)

        # construct the elasticity tensor (the leading ',' subscript denotes
        # the scalar factor in einsum)
        delta = identity(3)
        D_ijkl = (einsum(',ij,kl->ijkl', la, delta, delta) +
                  einsum(',ik,jl->ijkl', mu, delta, delta) +
                  einsum(',il,jk->ijkl', mu, delta, delta))

        return D_ijkl

    #-------------------------------------------------------------------------
    # Dock-based view with its own id
    #-------------------------------------------------------------------------

    traits_view = View(Include('polar_fn_group'),
                       dock='tab',
                       id='ibvpy.mats.mats3D.mats_3D_cmdm.MATS3D_cmdm',
                       kind='modal',
                       resizable=True,
                       scrollable=True,
                       width=0.6,
                       height=0.8,
                       buttons=['OK', 'Cancel']
                       )
class SkeletonTask(Task):
    """ A simple task for opening a blank editor.
    """

    #### Task interface #######################################################

    id = 'omnivore.framework.text_edit_task'
    name = 'Skeleton'

    #: The editor that is active in the central editor area (None if the
    #: editor area has not been created yet).
    active_editor = Property(Instance(IEditor),
                             depends_on='editor_area.active_editor')

    #: The central pane hosting the editors (set by create_central_pane).
    editor_area = Instance(IEditorAreaPane)

    menu_bar = SMenuBar(SMenu(TaskAction(name='New', method='new',
                                         accelerator='Ctrl+N'),
                              id='File', name='&File'),
                        SMenu(id='Edit', name='&Edit'),
                        SMenu(  # DockPaneToggleGroup(),
                              TaskToggleGroup(),
                              id='View', name='&View'))

    tool_bars = [
        SToolBar(TaskAction(method='new',
                            tooltip='New file',
                            image=ImageResource('document_new')),
                 image_size=(32, 32)),
    ]

    ###########################################################################
    # 'Task' interface.
    ###########################################################################

    def _default_layout_default(self):
        # pane1 on top of a horizontal split of pane2 and pane3, docked left
        return TaskLayout(left=VSplitter(
            PaneItem('text_edit.pane1'),
            HSplitter(
                PaneItem('text_edit.pane2'),
                PaneItem('text_edit.pane3'),
            ),
        ))

    def create_central_pane(self):
        """ Create the central pane: the editor area.
        """
        self.editor_area = EditorAreaPane()
        return self.editor_area

    def create_dock_panes(self):
        """ Create the three skeleton dock panes.
        """
        return [SkeletonPane1(), SkeletonPane2(), SkeletonPane3()]

    ###########################################################################
    # 'SkeletonTask' interface.
    ###########################################################################

    def new(self):
        """ Add a new blank editor to the editor area and make it active.
        """
        editor = Editor()
        self.editor_area.add_editor(editor)
        self.editor_area.activate_editor(editor)
        # NOTE(review): re-invokes the task's activated() hook after adding
        # the editor -- confirm this is intended.
        self.activated()

    #### Trait property getter/setters ########################################

    def _get_active_editor(self):
        if self.editor_area is None:
            return None
        return self.editor_area.active_editor
class MATSXDMicroplaneDamageFatigueWu(MATSEvalMicroplaneFatigue):
    '''
    Microplane Damage Model.

    The secant stiffness is assembled from microplane integrity factors
    (obtained via ``_get_phi``) using one of several fourth order
    formulations (``_get_S_*_tns``); ``get_corr_pred`` uses ``_get_S_4_tns``.
    '''

    # specification of the model dimension (2D, 3D)
    n_dim = Int

    # specification of number of engineering strain and stress components
    n_eng = Int

    #-------------------------------------------------------------------------
    # Configuration parameters
    #-------------------------------------------------------------------------

    model_version = Enum("compliance", "stiffness")

    symmetrization = Enum("product-type", "sum-type")

    regularization = Bool(False,
                          desc='Flag to use the element length projection'
                          ' in the direction of principle strains',
                          enter_set=True,
                          auto_set=False)

    elastic_debug = Bool(False,
                         desc='Switch to elastic behavior - used for debugging',
                         auto_set=False)

    double_constraint = Bool(False,
                             desc='Use double constraint to evaluate microplane elastic and fracture energy (Option effects only the response tracers)',
                             auto_set=False)

    #-------------------------------------------------------------------------
    # View specification
    #-------------------------------------------------------------------------

    config_param_vgroup = Group(Item('model_version', style='custom'),
                                # Item('stress_state', style='custom'),
                                Item('symmetrization', style='custom'),
                                Item('elastic_debug@'),
                                Item('double_constraint@'),
                                Spring(resizable=True),
                                label='Configuration parameters',
                                show_border=True,
                                dock='tab',
                                id='ibvpy.mats.matsXD.MATSXD_cmdm.config',
                                )

    traits_view = View(Include('polar_fn_group'),
                       dock='tab',
                       id='ibvpy.mats.matsXD.MATSXD_cmdm',
                       kind='modal',
                       resizable=True,
                       scrollable=True,
                       width=0.6,
                       height=0.8,
                       buttons=['OK', 'Cancel'])

    #-------------------------------------------------------------------------
    # Setup for computation within a supplied spatial context
    #-------------------------------------------------------------------------

    D4_e = Property

    def _get_D4_e(self):
        # Return the elasticity tensor
        return self.elasticity_tensors[0]

    #-------------------------------------------------------------------------
    # MICROPLANE-DISCRETIZATION RELATED METHOD
    #-------------------------------------------------------------------------

    # get the dyadic product of the microplane normals
    _MPNN = Property(depends_on='n_mp')

    @cached_property
    def _get__MPNN(self):
        # dyadic product of the microplane normals; n identifies the
        # microplane: N_nij = n_i n_j
        MPNN_nij = einsum('ni,nj->nij', self._MPN, self._MPN)
        return MPNN_nij

    # get Third order tangential tensor (operator) for each microplane
    _MPTT = Property(depends_on='n_mp')

    @cached_property
    def _get__MPTT(self):
        # Third order tangential tensor for each microplane:
        #   T_nijr = 0.5 * (n_i delta_jr + n_j delta_ir) - n_i n_j n_r
        # The second term must be laid out as (n, i, j, r); writing it as
        # 'njir' would silently duplicate the first term.
        delta = identity(3)
        MPTT_nijr = 0.5 * (
            einsum('ni,jr -> nijr', self._MPN, delta) +
            einsum('nj,ir -> nijr', self._MPN, delta) -
            2 * einsum('ni,nj,nr -> nijr', self._MPN, self._MPN, self._MPN))
        return MPTT_nijr

    def _get_e_vct_arr(self, eps_eng):
        # Projection of apparent strain onto the individual microplanes
        e_ni = einsum('nj,ji->ni', self._MPN, eps_eng)
        return e_ni

    def _get_e_N_arr(self, e_vct_arr):
        # get the normal strain array for each microplane
        eN_n = einsum('ni,ni->n', e_vct_arr, self._MPN)
        return eN_n

    def _get_e_T_vct_arr(self, e_vct_arr):
        # get the tangential strain vector array for each microplane
        eN_n = self._get_e_N_arr(e_vct_arr)
        eN_vct_ni = einsum('n,ni->ni', eN_n, self._MPN)
        return e_vct_arr - eN_vct_ni

    # Alternative methods for the kinematic constraint

    def _get_e_N_arr_2(self, eps_eng):
        # normal strain per microplane: e_N = N_nij : eps_ij
        return einsum('nij,ij->n', self._MPNN, eps_eng)

    def _get_e_t_vct_arr_2(self, eps_eng):
        # tangential strain vector per microplane: e_T_r = T_nijr : eps_ij
        # (use the cached property instead of re-evaluating the getter)
        return einsum('nijr,ij->nr', self._MPTT, eps_eng)

    def _get_e_vct_arr_2(self, eps_eng):
        # full microplane strain vector = normal part * n + tangential part
        e_N_n = self._get_e_N_arr_2(eps_eng)
        e_T_nr = self._get_e_t_vct_arr_2(eps_eng)
        return einsum('n,ni->ni', e_N_n, self._MPN) + e_T_nr

    def _get_I_vol_4(self):
        # The fourth order volumetric-identity tensor
        delta = identity(3)
        I_vol_ijkl = (1.0 / 3.0) * einsum('ij,kl -> ijkl', delta, delta)
        return I_vol_ijkl

    def _get_I_dev_4(self):
        # The fourth order deviatoric-identity tensor
        delta = identity(3)
        I_dev_ijkl = 0.5 * (einsum('ik,jl -> ijkl', delta, delta) +
                            einsum('il,jk -> ijkl', delta, delta)) \
            - (1 / 3.0) * einsum('ij,kl -> ijkl', delta, delta)
        return I_dev_ijkl

    def _get_P_vol(self):
        delta = identity(3)
        P_vol_ij = (1 / 3.0) * delta
        return P_vol_ij

    def _get_P_dev(self):
        delta = identity(3)
        P_dev_njkl = 0.5 * einsum('ni,ij,kl -> njkl', self._MPN, delta, delta)
        return P_dev_njkl

    def _get_PP_vol_4(self):
        # outer product of P_vol
        delta = identity(3)
        PP_vol_ijkl = (1 / 9.) * einsum('ij,kl -> ijkl', delta, delta)
        return PP_vol_ijkl

    def _get_PP_dev_4(self):
        # inner product of P_dev (one 4th order tensor per microplane n)
        delta = identity(3)
        MPN = self._MPN
        PP_dev_nijkl = 0.5 * (
            0.5 * (einsum('ni,nk,jl -> nijkl', MPN, MPN, delta) +
                   einsum('ni,nl,jk -> nijkl', MPN, MPN, delta)) +
            0.5 * (einsum('ik,nj,nl -> nijkl', delta, MPN, MPN) +
                   einsum('il,nj,nk -> nijkl', delta, MPN, MPN))) - \
            (1 / 3.) * (einsum('ni,nj,kl -> nijkl', MPN, MPN, delta) +
                        einsum('ij,nk,nl -> nijkl', delta, MPN, MPN)) + \
            (1 / 9.) * einsum('ij,kl -> ijkl', delta, delta)
        return PP_dev_nijkl

    def _get_e_equiv_arr(self, e_vct_arr):
        '''
        Returns an array of the microplane equivalent strains based on the
        array of microplane strain vectors:
            e_equiv = sqrt(<e_N>^2 + c_T * |e_T|^2)
        '''
        # normal strain magnitude for each microplane
        e_N_arr = self._get_e_N_arr(e_vct_arr)
        # positive part of the normal strain magnitude for each microplane
        e_N_pos_arr = (fabs(e_N_arr) + e_N_arr) / 2
        # normal strain vector for each microplane
        e_N_vct_arr = einsum('n,ni -> ni', e_N_arr, self._MPN)
        # tangent strain ratio
        c_T = self.c_T
        # tangential strain vector for each microplane
        e_T_vct_arr = e_vct_arr - e_N_vct_arr
        # squared tangential strain vector for each microplane
        e_TT_arr = einsum('ni,ni -> n', e_T_vct_arr, e_T_vct_arr)
        # equivalent strain for each microplane
        e_equiv_arr = sqrt(e_N_pos_arr * e_N_pos_arr + c_T * e_TT_arr)
        return e_equiv_arr

    def _get_e_max(self, e_equiv_arr, e_max_arr):
        '''
        Compares the equivalent microplane strain of a single microplane with
        the maximum strain reached in the loading history for the entire array
        '''
        bool_e_max = e_equiv_arr >= e_max_arr

        # [rch] fixed a bug here - this call was modifying the state array
        # at any invocation.
        #
        # The new value must be created, otherwise side-effect could occur
        # by writing into a state array.
        #
        new_e_max_arr = copy(e_max_arr)
        new_e_max_arr[bool_e_max] = e_equiv_arr[bool_e_max]
        return new_e_max_arr

    def _get_state_variables(self, sctx, eps_app_eng):
        '''
        Returns the current equivalent microplane strains.

        NOTE(review): the comparison with the history maximum stored in
        sctx.mats_state_array (via _get_e_max) is currently disabled, so no
        loading history is taken into account here.
        '''
        e_vct_arr = self._get_e_vct_arr(eps_app_eng)
        e_equiv_arr = self._get_e_equiv_arr(e_vct_arr)
        return e_equiv_arr

    def _get_phi_arr(self, sctx, eps_app_eng):
        # Returns an array of the integrity factors for all microplanes.
        e_max_arr = self._get_state_variables(sctx, eps_app_eng)
        phi_arr = self._get_phi(e_max_arr, sctx)[0]
        return phi_arr

    def _get_phi_mtx(self, sctx, eps_app_eng):
        # Returns the 2nd order damage tensor 'phi_mtx'
        # scalar integrity factor for each microplane
        phi_arr = self._get_phi_arr(sctx, eps_app_eng)
        # integration terms for each microplane
        phi_ij = einsum('n,n,nij->ij', phi_arr, self._MPW, self._MPNN)
        return phi_ij

    def _get_d_scalar(self, sctx, eps_app_eng):
        # scalar damage: weighted microplane average of (1 - phi), scaled by 1/3
        phi_arr = self._get_phi_arr(sctx, eps_app_eng)
        d_arr = 1.0 - phi_arr
        d = (1.0 / 3.0) * einsum('n,n->', d_arr, self._MPW)
        return d

    def _get_M_vol_tns(self, sctx, eps_app_eng):
        # volumetric damage tensor: (1 - d) * symmetric 4th order identity
        d = self._get_d_scalar(sctx, eps_app_eng)
        delta = identity(3)
        I_4th_ijkl = 0.5 * (einsum('ik,jl -> ijkl', delta, delta) +
                            einsum('il,jk -> ijkl', delta, delta))
        return (1 - d) * I_4th_ijkl

    def _get_M_dev_tns(self, phi_mtx):
        '''
        Returns the 4th order deviatoric damage tensor
        '''
        delta = identity(3)
        I_4th_ijkl = 0.5 * (einsum('ik,jl -> ijkl', delta, delta) +
                            einsum('il,jk -> ijkl', delta, delta))
        M_dev_ijkl = self.zeta_G * (
            0.5 * (einsum('ik,jl->ijkl', delta, phi_mtx) +
                   einsum('il,jk->ijkl', delta, phi_mtx)) +
            0.5 * (einsum('ik,jl->ijkl', phi_mtx, delta) +
                   einsum('il,jk->ijkl', phi_mtx, delta))) - \
            (2.0 * self.zeta_G - 1.0) * I_4th_ijkl * \
            (1.0 / 3.0) * trace(phi_mtx)
        return M_dev_ijkl

    #-------------------------------------------------------------------------
    # Secant stiffness (irreducible decomposition based on ODFs)
    #-------------------------------------------------------------------------

    def _get_S_1_tns(self, sctx, eps_app_eng):
        #----------------------------------------------------------------------
        # Returns the fourth order secant stiffness tensor (eq.1)
        #----------------------------------------------------------------------
        K0 = self.E / (1. - 2. * self.nu)
        G0 = self.E / (1. + self.nu)

        e_max_arr = self._get_state_variables(sctx, eps_app_eng)
        d_n = self._get_phi(e_max_arr, sctx)[1]

        PP_vol_4 = self._get_PP_vol_4()
        PP_dev_4 = self._get_PP_dev_4()
        I_dev_4 = self._get_I_dev_4()

        S_1_ijkl = K0 * einsum('n,n,ijkl->ijkl', d_n, self._MPW, PP_vol_4) + \
            G0 * 2. * self.zeta_G * \
            einsum('n,n,nijkl->ijkl', d_n, self._MPW, PP_dev_4) - \
            (1. / 3.) * (2. * self.zeta_G - 1.) * G0 * \
            einsum('n,n,ijkl->ijkl', d_n, self._MPW, I_dev_4)
        return S_1_ijkl

    def _get_S_2_tns(self, sctx, eps_app_eng):
        #----------------------------------------------------------------------
        # Returns the fourth order secant stiffness tensor (eq.2)
        #----------------------------------------------------------------------
        K0 = self.E / (1. - 2. * self.nu)
        G0 = self.E / (1. + self.nu)

        I_vol_ijkl = self._get_I_vol_4()
        I_dev_ijkl = self._get_I_dev_4()
        phi_mtx = self._get_phi_mtx(sctx, eps_app_eng)
        M_vol_ijkl = self._get_M_vol_tns(sctx, eps_app_eng)
        M_dev_ijkl = self._get_M_dev_tns(phi_mtx)

        S_2_ijkl = K0 * einsum('ijmn,mnrs,rskl -> ijkl',
                               I_vol_ijkl, M_vol_ijkl, I_vol_ijkl) \
            + G0 * einsum('ijmn,mnrs,rskl -> ijkl',
                          I_dev_ijkl, M_dev_ijkl, I_dev_ijkl)
        return S_2_ijkl

    def _get_S_22_tns(self, sctx, eps_app_eng):
        #----------------------------------------------------------------------
        # Returns the fourth order secant stiffness tensor (double orthotropic)
        #----------------------------------------------------------------------
        K0 = self.E / (1. - 2. * self.nu)
        G0 = self.E / (1. + self.nu)

        I_vol_ijkl = self._get_I_vol_4()
        I_dev_ijkl = self._get_I_dev_4()
        delta = identity(3)
        phi_mtx = self._get_phi_mtx(sctx, eps_app_eng)
        D_ij = delta - phi_mtx
        d = (1. / 3.) * trace(D_ij)
        D_bar_ij = self.zeta_G * (D_ij - d * delta)

        # damaged stiffness without simplification
        S_22_ijkl = (
            (1. - d) * K0 * I_vol_ijkl + (1. - d) * G0 * I_dev_ijkl +
            (2. / 3.) * (G0) *
            (einsum('ij,kl -> ijkl', delta, D_bar_ij) +
             einsum('ij,kl -> ijkl', D_bar_ij, delta)) -
            G0 *
            (0.5 * (einsum('ik,jl -> ijkl', delta, D_bar_ij) +
                    einsum('il,jk -> ijkl', D_bar_ij, delta)) +
             0.5 * (einsum('ik,jl -> ijkl', D_bar_ij, delta) +
                    einsum('il,jk -> ijkl', delta, D_bar_ij))))
        return S_22_ijkl

    def _get_S_3_tns(self, sctx, eps_app_eng):
        #----------------------------------------------------------------------
        # Returns the fourth order secant stiffness tensor (eq.3)
        #----------------------------------------------------------------------
        K0 = self.E / (1. - 2. * self.nu)
        G0 = self.E / (1. + self.nu)

        I_vol_ijkl = self._get_I_vol_4()
        I_dev_ijkl = self._get_I_dev_4()

        # undamaged stiffness
        S_0_ijkl = K0 * I_vol_ijkl + G0 * I_dev_ijkl

        e_max_arr = self._get_state_variables(sctx, eps_app_eng)
        d_n = 1.0 - self._get_phi(e_max_arr, sctx)[0]

        PP_vol_4 = self._get_PP_vol_4()
        PP_dev_4 = self._get_PP_dev_4()

        delta = identity(3)
        I_4th_ijkl = einsum('ik,jl -> ijkl', delta, delta)

        D_ijkl = einsum('n,n,ijkl->ijkl', d_n, self._MPW, PP_vol_4) + \
            einsum('n,n,nijkl->ijkl', d_n, self._MPW, PP_dev_4)

        phi_ijkl = (I_4th_ijkl - D_ijkl)

        S_ijkl = einsum('ijmn,mnkl->ijkl', phi_ijkl, S_0_ijkl)
        return S_ijkl

    def _get_S_4_tns(self, sctx, eps_app_eng):
        #----------------------------------------------------------------------
        # Returns the fourth order secant stiffness tensor (double orthotropic)
        #----------------------------------------------------------------------
        K0 = self.E / (1.0 - 2.0 * self.nu)
        G0 = self.E / (1.0 + self.nu)

        I_vol_ijkl = self._get_I_vol_4()
        I_dev_ijkl = self._get_I_dev_4()
        delta = identity(3)
        phi_mtx = self._get_phi_mtx(sctx, eps_app_eng)
        D_ij = delta - phi_mtx
        # scalar damage from the microplane integrity factors (the former
        # trace-based value was only printed and then overwritten; the debug
        # prints on this hot path have been removed)
        d = self._get_d_scalar(sctx, eps_app_eng)
        D_bar_ij = self.zeta_G * (D_ij - d * delta)

        S_4_ijkl = (
            (1. - d) * K0 * I_vol_ijkl + (1. - d) * G0 * I_dev_ijkl +
            (2. / 3.) * (G0 - K0) *
            (einsum('ij,kl -> ijkl', delta, D_bar_ij) +
             einsum('ij,kl -> ijkl', D_bar_ij, delta)) +
            0.5 * (- K0 + 2. * G0) *
            (0.5 * (einsum('ik,jl -> ijkl', delta, D_bar_ij) +
                    einsum('il,jk -> ijkl', D_bar_ij, delta)) +
             0.5 * (einsum('ik,jl -> ijkl', D_bar_ij, delta) +
                    einsum('il,jk-> ijkl', delta, D_bar_ij))))
        return S_4_ijkl

    def _get_S_44_tns(self, sctx, eps_app_eng):
        #----------------------------------------------------------------------
        # Returns the fourth order secant stiffness tensor (restrictive
        # orthotropic)
        #----------------------------------------------------------------------
        K0 = self.E / (1. - 2. * self.nu)
        G0 = self.E / (1. + self.nu)

        delta = identity(3)
        phi_mtx = self._get_phi_mtx(sctx, eps_app_eng)

        # damaged stiffness without simplification
        S_44_ijkl = (1. / 3.) * (K0 - G0) * 0.5 * \
            (einsum('ij,kl -> ijkl', delta, phi_mtx) +
             einsum('ij,kl -> ijkl', phi_mtx, delta)) + \
            G0 * 0.5 * \
            (0.5 * (einsum('ik,jl -> ijkl', delta, phi_mtx) +
                    einsum('il,jk -> ijkl', phi_mtx, delta)) +
             0.5 * (einsum('ik,jl -> ijkl', phi_mtx, delta) +
                    einsum('il,jk -> ijkl', delta, phi_mtx)))
        return S_44_ijkl

    #-------------------------------------------------------------------------
    # Evaluation - get the corrector and predictor
    #-------------------------------------------------------------------------

    def get_corr_pred(self, sctx, eps_app_eng):
        '''
        Corrector predictor computation.
        @param eps_app_eng input variable - engineering strain

        NOTE(review): with elastic_debug on, a (stress, stiffness) tuple is
        returned; otherwise only the stress tensor -- confirm callers expect
        this asymmetry.
        '''
        # -----------------------------------------------------------------
        # for debugging purposes only: if elastic_debug is switched on,
        # linear elastic material is used
        # -----------------------------------------------------------------
        if self.elastic_debug:
            # NOTE: This must be copied otherwise self.D2_e gets modified when
            # essential boundary conditions are inserted
            D2_e = copy(self.D2_e)
            sig_eng = tensordot(D2_e, eps_app_eng, [[1], [0]])
            return sig_eng, D2_e

        # (crack-band regularization via get_regularizing_length is
        # currently disabled)

        #------------------------------------------------------------------
        # Return stresses (corrector)
        #------------------------------------------------------------------
        eps_e_mtx = eps_app_eng
        S_ijkl = self._get_S_4_tns(sctx, eps_app_eng)
        sig_ij = einsum('ijkl,kl -> ij', S_ijkl, eps_e_mtx)
        return sig_ij
class MATS3DMicroplaneDamageWu(MATSXDMicroplaneDamageFatigueWu):

    # Number of spatial dimensions.
    #
    n_dim = Constant(3)

    # Number of components of the engineering tensor representation.
    #
    n_eng = Constant(6)

    #-------------------------------------------------------------------------
    # PolarDiscr related data
    #-------------------------------------------------------------------------

    # Number of microplanes -- fixed at 28 for the 3D scheme used below.
    #
    n_mp = Constant(28)

    # Unit normal vectors of the microplanes, one row per microplane.
    #
    _MPN = Property(depends_on='n_mp')

    @cached_property
    def _get__MPN(self):
        # Every generator triple expands into four microplane normals by
        # flipping the sign of its 2nd and/or 3rd component, in the order
        # (+, +), (+, -), (-, +), (-, -).
        generators = (
            (.577350259, .577350259, .577350259),
            (.935113132, .250562787, .250562787),
            (.250562787, .935113132, .250562787),
            (.250562787, .250562787, .935113132),
            (.186156720, .694746614, .694746614),
            (.694746614, .186156720, .694746614),
            (.694746614, .694746614, .186156720),
        )
        signs = (1., -1.)
        return array([[nx, sy * ny, sz * nz]
                      for nx, ny, nz in generators
                      for sy in signs
                      for sz in signs])

    # Integration weights of the microplanes, in the same order as _MPN.
    #
    _MPW = Property(depends_on='n_mp')

    @cached_property
    def _get__MPW(self):
        # Gaussian-integration weights over the unit hemisphere: 4 planes of
        # the first family followed by 12 of each of the other two families.
        # The raw weights sum to 0.5 (cf. [BazLuz04]) and have to be
        # multiplied by 6 (cf. [Baz05]).
        raw_weights = ([.0160714276] * 4 +
                       [.0204744730] * 12 +
                       [.0158350505] * 12)
        return array(raw_weights) * 6.0

    #-------------------------------------------------------------------------
    # Dock-based view with its own id
    #-------------------------------------------------------------------------

    traits_view = View(
        Include('polar_fn_group'),
        dock='tab',
        id='ibvpy.mats.mats3D.mats_3D_cmdm.MATS3D_cmdm',
        kind='modal',
        resizable=True,
        scrollable=True,
        width=0.6,
        height=0.8,
        buttons=['OK', 'Cancel'],
    )
class GaussianMixture1DView(cytoflow.views.HistogramView):
    """
    A histogram of the op's channel with the estimated Gaussian mixture
    components drawn on top of it.

    Attributes
    ----------
    op : Instance(GaussianMixture1DOp)
        The op whose parameters we're viewing.
    """

    id = 'edu.mit.synbio.cytoflow.view.gaussianmixture1dview'
    friendly_id = "1D Gaussian Mixture Diagnostic Plot"

    # TODO - why can't I use GaussianMixture1DOp here?
    op = Instance(IOperation)
    channel = DelegatesTo('op')
    scale = DelegatesTo('op')

    # The op's `by` conditions that are not used as an x or y facet.  Kept in
    # `op.by` order (rather than set order) so group names are deterministic.
    _by = Property(List)

    def _get__by(self):
        facets = [f for f in (self.xfacet, self.yfacet) if f]
        return [b for b in self.op.by if b not in facets]

    def enum_plots(self, experiment):
        """
        Returns an iterator over the possible plots that this View can
        produce.  The values returned can be passed to "plot".
        """

        if self.xfacet and self.xfacet not in experiment.conditions:
            raise util.CytoflowViewError(
                "X facet {} not in the experiment".format(self.xfacet))

        if self.xfacet and self.xfacet not in self.op.by:
            raise util.CytoflowViewError(
                "X facet {} must be in GaussianMixture1DOp.by, which is {}".
                format(self.xfacet, self.op.by))

        if self.yfacet and self.yfacet not in experiment.conditions:
            raise util.CytoflowViewError(
                "Y facet {0} not in the experiment".format(self.yfacet))

        if self.yfacet and self.yfacet not in self.op.by:
            raise util.CytoflowViewError(
                "Y facet {} must be in GaussianMixture1DOp.by, which is {}".
                format(self.yfacet, self.op.by))

        for b in self.op.by:
            if b not in experiment.data:
                raise util.CytoflowOpError("Aggregation metadata {0} not found"
                                           " in the experiment".format(b))

        class plot_enum(object):
            # Yields the group names of `experiment.data` grouped by
            # `view._by`; with nothing to group by, yields a single None.
            # Supports both the Python 2 (`next`) and Python 3 (`__next__`)
            # iterator protocols.

            def __init__(self, view, experiment):
                self._iter = None
                self._returned = False

                if view._by:
                    self._iter = iter(experiment.data.groupby(view._by))

            def __iter__(self):
                return self

            def next(self):
                if self._iter is not None:
                    # groupby iteration yields (name, group) pairs
                    return next(self._iter)[0]
                if self._returned:
                    raise StopIteration
                self._returned = True
                return None

            __next__ = next

        return plot_enum(self, experiment)

    def plot(self, experiment, plot_name=None, **kwargs):
        """
        Plot the plots.

        If the op is grouped by conditions that are not facets and no
        `plot_name` is given, one plot per group (from `enum_plots`) is
        drawn.  Returns the facet grid from the histogram view, or None
        when several sub-plots were drawn.
        """
        if not experiment:
            raise util.CytoflowViewError("No experiment specified")

        if not self.op.channel:
            raise util.CytoflowViewError("No channel specified")

        experiment = experiment.clone()

        # try to apply the current operation
        try:
            experiment = self.op.apply(experiment)
        except util.CytoflowOpError:
            # could have failed because no GMMs have been estimated, or because
            # op has already been applied
            pass

        # if apply() succeeded (or wasn't needed), set up the hue facet
        if self.op.name and self.op.name in experiment.conditions:
            if self.huefacet and self.huefacet != self.op.name:
                warn(
                    "Resetting huefacet to the model component (was {}, now {})."
                    .format(self.huefacet, self.op.name))
            self.huefacet = self.op.name

        if self.subset:
            try:
                experiment = experiment.query(self.subset)
                experiment.data.reset_index(drop=True, inplace=True)
            except Exception:
                raise util.CytoflowViewError(
                    "Subset string '{0}' isn't valid".format(self.subset))

            if len(experiment) == 0:
                raise util.CytoflowViewError(
                    "Subset string '{0}' returned no events".format(
                        self.subset))

        # figure out common x limits for multiple plots
        # adjust the limits to clip extreme values
        min_quantile = kwargs.pop("min_quantile", 0.001)
        max_quantile = kwargs.pop("max_quantile", 0.999)

        xlim = kwargs.pop("xlim", None)
        if xlim is None:
            xlim = (experiment.data[self.op.channel].quantile(min_quantile),
                    experiment.data[self.op.channel].quantile(max_quantile))

        # see if we're making subplots
        if self._by and plot_name is None:
            for plot in self.enum_plots(experiment):
                self.plot(experiment, plot, xlim=xlim, **kwargs)
                plt.title("{0} = {1}".format(self.op.by, plot))
            return

        if plot_name is not None:
            if not self._by:
                raise util.CytoflowViewError(
                    "Plot {} not from plot_enum".format(plot_name))

            groupby = experiment.data.groupby(self._by)

            if plot_name not in set(groupby.groups.keys()):
                raise util.CytoflowViewError(
                    "Plot {} not from plot_enum".format(plot_name))

            experiment.data = groupby.get_group(plot_name)
            experiment.data.reset_index(drop=True, inplace=True)

        # get the parameterized scale object back from the op
        scale = self.op._scale

        # plot the histogram, whether or not we're plotting distributions on top
        g = super(GaussianMixture1DView, self).plot(experiment,
                                                    scale=scale,
                                                    xlim=xlim,
                                                    **kwargs)

        # plot the actual distribution on top of it.
        # None marks "no row / column facet", so that falsy facet values
        # (e.g. 0) are still part of the GMM key.
        row_names = g.row_names if g.row_names else [None]
        col_names = g.col_names if g.col_names else [None]

        for (i, row), (j, col) in product(enumerate(row_names),
                                          enumerate(col_names)):
            facets = [f for f in (row, col) if f is not None]
            gmm = self._get_facet_gmm(plot_name, facets)
            if gmm is None:
                # keep going: other facets may still have an estimated GMM
                continue

            self._plot_gmm_components(g.facet_axis(i, j), gmm, scale)

        return g

    def _get_facet_gmm(self, plot_name, facets):
        """
        Look up the estimated GMM for one facet of the plot.

        The key is `plot_name` (a scalar or tuple group name from
        `enum_plots`, or None) followed by the facet values.  A one-element
        key is unwrapped to a scalar; an empty key (no grouping at all) is
        looked up under `True`.  Returns None if no matching GMM exists.
        """
        if plot_name is None:
            key = []
        elif isinstance(plot_name, tuple):
            key = list(plot_name)
        else:
            # a scalar group name -- list() would split a string into chars
            key = [plot_name]
        key.extend(facets)

        if not key:
            if True in self.op._gmms:
                return self.op._gmms[True]
            return None

        gmm_name = key[0] if len(key) == 1 else tuple(key)
        if gmm_name in self.op._gmms:
            return self.op._gmms[gmm_name]

        # there weren't any events in this subset to estimate a GMM from
        warn("No estimated GMM for plot {}".format(gmm_name),
             util.CytoflowViewWarning)
        return None

    def _plot_gmm_components(self, ax, gmm, scale):
        """
        Draw each component of `gmm` on `ax` as a normal pdf, scaled to the
        area of the corresponding histogram patch already drawn on `ax`.
        """
        palette = sns.color_palette()

        for k in range(len(gmm.means_)):
            # we want to scale the plots so they have the same area under the
            # curve as the histograms.  it used to be that we got the area from
            # repeating the assignments, then calculating bin widths, etc. but
            # really, if we just plotted the damn thing already, we can get the
            # area of the plot from the Polygon patch that we just plotted!
            patch = ax.patches[k]
            xy = patch.get_xy()
            pdf_scale = poly_area([scale(p[0]) for p in xy],
                                  [p[1] for p in xy])

            # use this facet's own axes rather than whatever plt.gca() is
            plt_min, plt_max = ax.get_xlim()
            x = scale.inverse(
                np.linspace(scale(plt_min), scale(plt_max), 500))

            mean = gmm.means_[k][0]
            stdev = np.sqrt(gmm.covariances_[k][0])
            y = stats.norm.pdf(scale(x), mean, stdev) * pdf_scale

            ax.plot(x, y, color=palette[k % len(palette)])
class Slider(Component):
    """ A horizontal or vertical slider bar """

    #------------------------------------------------------------------------
    # Model traits
    #------------------------------------------------------------------------

    min = Float()

    max = Float()

    value = Float()

    # The number of ticks to show on the slider.
    num_ticks = Int(4)

    #------------------------------------------------------------------------
    # Bar and endcap appearance
    #------------------------------------------------------------------------

    # Whether this is a horizontal or vertical slider
    orientation = Enum("h", "v")

    # The thickness, in pixels, of the lines used to render the ticks,
    # endcaps, and main slider bar.
    bar_width = Int(4)

    bar_color = ColorTrait("black")

    # Whether or not to render endcaps on the slider bar
    endcaps = Bool(True)

    # The extent of the endcaps, in pixels.  This is a read-only property,
    # since the endcap size can be set as either a fixed number of pixels or
    # a percentage of the widget's size in the transverse direction.
    endcap_size = Property

    # The extent of the tickmarks, in pixels.  This is a read-only property,
    # since the tick size can be set as either a fixed number of pixels or
    # a percentage of the widget's size in the transverse direction.
    tick_size = Property

    #------------------------------------------------------------------------
    # Slider appearance
    #------------------------------------------------------------------------

    # The kind of marker to use for the slider.
    slider = SliderMarkerTrait("rect")

    # If the slider marker is "rect", this is the thickness of the slider,
    # i.e. its extent in the dimension parallel to the long axis of the widget.
    # For other slider markers, this has no effect.
    slider_thickness = Int(9)

    # The size of the slider, in pixels.  This is a read-only property, since
    # the slider size can be set as either a fixed number of pixels or a
    # percentage of the widget's size in the transverse direction.
    slider_size = Property

    # For slider markers with a filled area, this is the color of the filled
    # area.  For slider markers that are just lines/strokes (e.g. cross, plus),
    # this is the color of the stroke.
    slider_color = ColorTrait("red")

    # For slider markers with a filled area, this is the color of the outline
    # border drawn around the filled area.  For slider markers that have just
    # lines/strokes, this has no effect.
    slider_border = ColorTrait("none")

    # For slider markers with a filled area, this is the width, in pixels,
    # of the outline around the area.  For slider markers that are just lines/
    # strokes, this is the thickness of the stroke.
    slider_outline_width = Int(1)

    # The kiva.CompiledPath representing the custom path to render for the
    # slider, if the **slider** trait is set to "custom".
    custom_slider = Any()

    #------------------------------------------------------------------------
    # Interaction traits
    #------------------------------------------------------------------------

    # Can this slider be interacted with, or is it just a display
    interactive = Bool(True)

    mouse_button = Enum("left", "right")

    event_state = Enum("normal", "dragging")

    #------------------------------------------------------------------------
    # Private traits
    #------------------------------------------------------------------------

    # Returns the coordinate index (0 or 1) corresponding to our orientation.
    # Used internally; read-only property.
    axis_ndx = Property()

    _slider_size_mode = Enum("fixed", "percent")
    _slider_percent = Float(0.0)
    _cached_slider_size = Int(10)

    _endcap_size_mode = Enum("fixed", "percent")
    _endcap_percent = Float(0.0)
    _cached_endcap_size = Int(20)

    _tick_size_mode = Enum("fixed", "percent")
    _tick_size_percent = Float(0.0)
    _cached_tick_size = Int(20)

    # A tuple of (dx, dy) of the difference between the mouse position and
    # center of the slider.
    _offset = Any((0, 0))

    def set_range(self, min, max):
        self.min = min
        self.max = max

    def map_screen(self, val):
        """ Returns an (x,y) coordinate corresponding to the location of
        **val** on the slider.
        """
        # Some local variables to handle orientation dependence
        axis_ndx = self.axis_ndx
        other_ndx = 1 - axis_ndx
        screen_low = self.position[axis_ndx]
        screen_high = screen_low + self.bounds[axis_ndx]

        # The return coordinate.  The return value along the non-primary
        # axis will be the same in all cases.
        coord = [0, 0]
        coord[other_ndx] = self.position[other_ndx] + \
            self.bounds[other_ndx] / 2

        # Handle exceptional/boundary cases.  The empty range has to be
        # tested first; otherwise every value satisfies one of the clipping
        # tests and the centring branch can never run.
        if self.min == self.max:
            coord[axis_ndx] = (screen_low + screen_high) / 2
        elif val <= self.min:
            coord[axis_ndx] = screen_low
        elif val >= self.max:
            coord[axis_ndx] = screen_high
        else:
            # Handle normal cases
            coord[axis_ndx] = (val - self.min) / (
                self.max - self.min) * self.bounds[axis_ndx] + screen_low
        return coord

    def map_data(self, x, y, clip=True):
        """ Returns a value between min and max that corresponds to the given
        x and y values.

        Parameters
        ==========
        x, y : Float
            The screen coordinates to map
        clip : Bool (default=True)
            Whether points outside the range should be clipped to the max
            or min value of the slider (depending on which it's closer to).
            When False, such points are extrapolated linearly.

        Returns
        =======
        value : Float
        """
        # Some local variables to handle orientation dependence
        axis_ndx = self.axis_ndx
        screen_low = self.position[axis_ndx]
        screen_high = screen_low + self.bounds[axis_ndx]
        if self.orientation == "h":
            coord = x
        else:
            coord = y

        # A zero-length widget has to be tested first; otherwise one of the
        # clipping tests below always catches it.
        if screen_high == screen_low:
            return (self.max + self.min) / 2

        if clip:
            if coord >= screen_high:
                return self.max
            if coord <= screen_low:
                return self.min

        # Handle normal cases; float() avoids integer division on Python 2
        # when the screen coordinates are ints.
        return float(coord - screen_low) / self.bounds[axis_ndx] * \
            (self.max - self.min) + self.min

    def set_slider_pixels(self, pixels):
        """ Sets the width of the slider to be a fixed number of pixels

        Parameters
        ==========
        pixels : int
            The number of pixels wide that the slider should be
        """
        self._slider_size_mode = "fixed"
        self._cached_slider_size = pixels

    def set_slider_percent(self, percent):
        """ Sets the width of the slider to be a percentage of the width
        of the slider widget.

        Parameters
        ==========
        percent : float
            The percentage, between 0.0 and 1.0
        """
        self._slider_size_mode = "percent"
        self._slider_percent = percent
        self._update_sizes()

    def set_endcap_pixels(self, pixels):
        """ Sets the width of the endcap to be a fixed number of pixels

        Parameters
        ==========
        pixels : int
            The number of pixels wide that the endcap should be
        """
        self._endcap_size_mode = "fixed"
        self._cached_endcap_size = pixels

    def set_endcap_percent(self, percent):
        """ Sets the width of the endcap to be a percentage of the width
        of the slider widget.

        Parameters
        ==========
        percent : float
            The percentage, between 0.0 and 1.0
        """
        self._endcap_size_mode = "percent"
        self._endcap_percent = percent
        self._update_sizes()

    def set_tick_pixels(self, pixels):
        """ Sets the width of the tick marks to be a fixed number of pixels

        Parameters
        ==========
        pixels : int
            The number of pixels wide that the tick marks should be
        """
        self._tick_size_mode = "fixed"
        self._cached_tick_size = pixels

    def set_tick_percent(self, percent):
        """ Sets the width of the tick marks to be a percentage of the width
        of the slider widget.

        Parameters
        ==========
        percent : float
            The percentage, between 0.0 and 1.0
        """
        self._tick_size_mode = "percent"
        # Must be the declared trait that _update_sizes() reads.
        self._tick_size_percent = percent
        self._update_sizes()

    #------------------------------------------------------------------------
    # Rendering methods
    #------------------------------------------------------------------------

    def _draw_mainlayer(self, gc, view_bounds=None, mode="normal"):
        bar_x = self.x + self.width / 2
        bar_y = self.y + self.height / 2

        # Draw the bar and endcaps
        gc.set_stroke_color(self.bar_color_)
        gc.set_line_width(self.bar_width)
        if self.orientation == "h":
            gc.move_to(self.x, bar_y)
            gc.line_to(self.x2, bar_y)
            gc.stroke_path()
            if self.endcaps:
                start_y = bar_y - self._cached_endcap_size / 2
                end_y = bar_y + self._cached_endcap_size / 2
                gc.move_to(self.x, start_y)
                gc.line_to(self.x, end_y)
                gc.move_to(self.x2, start_y)
                gc.line_to(self.x2, end_y)
            if self.num_ticks > 0:
                x_pts = linspace(self.x, self.x2,
                                 self.num_ticks + 2).astype(int)
                starts = zeros((len(x_pts), 2), dtype=int)
                starts[:, 0] = x_pts
                starts[:, 1] = bar_y - self._cached_tick_size / 2
                ends = starts.copy()
                ends[:, 1] = bar_y + self._cached_tick_size / 2
                gc.line_set(starts, ends)
        else:
            gc.move_to(bar_x, self.y)
            gc.line_to(bar_x, self.y2)
            if self.endcaps:
                start_x = bar_x - self._cached_endcap_size / 2
                end_x = bar_x + self._cached_endcap_size / 2
                gc.move_to(start_x, self.y)
                gc.line_to(end_x, self.y)
                gc.move_to(start_x, self.y2)
                gc.line_to(end_x, self.y2)
            if self.num_ticks > 0:
                y_pts = linspace(self.y, self.y2,
                                 self.num_ticks + 2).astype(int)
                starts = zeros((len(y_pts), 2), dtype=int)
                starts[:, 1] = y_pts
                starts[:, 0] = bar_x - self._cached_tick_size / 2
                ends = starts.copy()
                ends[:, 0] = bar_x + self._cached_tick_size / 2
                gc.line_set(starts, ends)
        gc.stroke_path()

        # Draw the slider
        if self.slider == "rect":
            gc.set_fill_color(self.slider_color_)
            gc.set_stroke_color(self.slider_border_)
            gc.set_line_width(self.slider_outline_width)
            rect = self._get_rect_slider_bounds()
            gc.rect(*rect)
            gc.draw_path()
        else:
            pt = self.map_screen(self.value)
            self._render_marker(gc, pt, self._cached_slider_size,
                                self.slider_(), self.custom_slider)

    def _get_rect_slider_bounds(self):
        """ Returns the (x, y, w, h) bounds of the rectangle representing the
        slider.  Used for rendering and hit detection.

        The rectangle is centred on the current value along the bar and on
        the bar itself in the transverse direction.
        """
        bar_x = self.x + self.width / 2
        bar_y = self.y + self.height / 2
        pt = self.map_screen(self.value)
        if self.orientation == "h":
            slider_height = self._cached_slider_size
            return (pt[0] - self.slider_thickness / 2,
                    bar_y - slider_height / 2,
                    self.slider_thickness, slider_height)
        else:
            slider_width = self._cached_slider_size
            return (bar_x - slider_width / 2,
                    pt[1] - self.slider_thickness / 2,
                    slider_width, self.slider_thickness)

    def _render_marker(self, gc, point, size, marker, custom_path):
        with gc:
            gc.begin_path()
            if marker.draw_mode == STROKE:
                gc.set_stroke_color(self.slider_color_)
                # Per the trait docs, slider_outline_width is the stroke
                # thickness for line-only markers (slider_thickness only
                # applies to the "rect" slider).
                gc.set_line_width(self.slider_outline_width)
            else:
                gc.set_fill_color(self.slider_color_)
                gc.set_stroke_color(self.slider_border_)
                gc.set_line_width(self.slider_outline_width)

            if hasattr(gc, "draw_marker_at_points") and \
                    (marker.__class__ != CustomMarker) and \
                    (gc.draw_marker_at_points([point], size,
                                              marker.kiva_marker) != 0):
                pass
            elif hasattr(gc, "draw_path_at_points"):
                if marker.__class__ != CustomMarker:
                    path = gc.get_empty_path()
                    marker.add_to_path(path, size)
                    mode = marker.draw_mode
                else:
                    path = custom_path
                    mode = STROKE
                if not marker.antialias:
                    gc.set_antialias(False)
                gc.draw_path_at_points([point], path, mode)
            else:
                if not marker.antialias:
                    gc.set_antialias(False)
                if marker.__class__ != CustomMarker:
                    gc.translate_ctm(*point)
                    # Kiva GCs have a path-drawing interface
                    marker.add_to_path(gc, size)
                    gc.draw_path(marker.draw_mode)
                else:
                    path = custom_path
                    gc.translate_ctm(*point)
                    gc.add_path(path)
                    gc.draw_path(STROKE)

    #------------------------------------------------------------------------
    # Interaction event handlers
    #------------------------------------------------------------------------

    def normal_left_down(self, event):
        if self.mouse_button == "left":
            return self._mouse_pressed(event)

    def dragging_left_up(self, event):
        if self.mouse_button == "left":
            return self._mouse_released(event)

    def normal_right_down(self, event):
        if self.mouse_button == "right":
            return self._mouse_pressed(event)

    def dragging_right_up(self, event):
        if self.mouse_button == "right":
            return self._mouse_released(event)

    def dragging_mouse_move(self, event):
        dx, dy = self._offset
        self.value = self.map_data(event.x - dx, event.y - dy)
        event.handled = True
        self.request_redraw()

    def dragging_mouse_leave(self, event):
        self.event_state = "normal"

    def _mouse_pressed(self, event):
        # Determine the slider bounds so we can hit test it
        pt = self.map_screen(self.value)
        if self.slider == "rect":
            x, y, w, h = self._get_rect_slider_bounds()
            x2 = x + w
            y2 = y + h
        else:
            x, y = pt
            size = self._cached_slider_size
            x -= size / 2
            y -= size / 2
            x2 = x + size
            y2 = y + size

        # Hit test both the slider and against the bar.  If the user has
        # clicked on the bar but outside of the slider, we set the _offset
        # and call dragging_mouse_move() to teleport the slider to the
        # mouse click position.
        if self.orientation == "v" and (x <= event.x <= x2):
            if not (y <= event.y <= y2):
                self._offset = (event.x - pt[0], 0)
                self.dragging_mouse_move(event)
            else:
                self._offset = (event.x - pt[0], event.y - pt[1])
        elif self.orientation == "h" and (y <= event.y <= y2):
            if not (x <= event.x <= x2):
                self._offset = (0, event.y - pt[1])
                self.dragging_mouse_move(event)
            else:
                self._offset = (event.x - pt[0], event.y - pt[1])
        else:
            # The mouse click missed the bar and the slider.
            return

        event.handled = True
        self.event_state = "dragging"
        return

    def _mouse_released(self, event):
        self.event_state = "normal"
        event.handled = True

    #------------------------------------------------------------------------
    # Private trait event handlers and property getters/setters
    #------------------------------------------------------------------------

    def _get_axis_ndx(self):
        if self.orientation == "h":
            return 0
        else:
            return 1

    def _get_slider_size(self):
        return self._cached_slider_size

    def _get_endcap_size(self):
        return self._cached_endcap_size

    def _get_tick_size(self):
        return self._cached_tick_size

    @on_trait_change("bounds,bounds_items")
    def _update_sizes(self):
        """ Recompute the cached pixel size of every element (slider,
        endcaps, ticks) that is sized as a percentage of the widget. """
        # The transverse extent: height for a horizontal slider, width for
        # a vertical one.
        if self.orientation == "h":
            extent = self.height
        else:
            extent = self.width

        if self._slider_size_mode == "percent":
            self._cached_slider_size = int(extent * self._slider_percent)
        if self._endcap_size_mode == "percent":
            self._cached_endcap_size = int(extent * self._endcap_percent)
        if self._tick_size_mode == "percent":
            self._cached_tick_size = int(extent * self._tick_size_percent)
        return
class Loop(HasTraits):
    """ A current loop class.
    """

    #-------------------------------------------------------------------------
    # Public traits
    #-------------------------------------------------------------------------

    direction = Array(float, value=(0, 0, 1), cols=3,
                      shape=(3, ), desc='directing vector of the loop',
                      enter_set=True, auto_set=False)

    radius = CFloat(0.1, desc='radius of the loop',
                    enter_set=True, auto_set=False)

    position = Array(float, value=(0, 0, 0), cols=3,
                     shape=(3, ), desc='position of the center of the loop',
                     enter_set=True, auto_set=False)

    _plot = None

    Bnorm = Property(depends_on='direction,position,radius')

    view = View('position', 'direction', 'radius', '_')

    #-------------------------------------------------------------------------
    # Loop interface
    #-------------------------------------------------------------------------

    def base_vectors(self):
        """ Returns 3 orthonormal base vectors, the first one colinear to
            the axis of the loop.
        """
        # normalize n: divide by the norm (not the squared norm), so that a
        # non-unit `direction` still yields unit base vectors
        n = self.direction / np.sqrt((self.direction**2).sum(axis=-1))

        # choose two vectors perpendicular to n
        # choice is arbitrary since the coil is symmetric about n
        if np.hypot(n[1], n[2]) < 1e-12:
            # n is (anti)parallel to the x axis; (0, n2, -n1) would vanish
            l = np.r_[n[2], 0, -n[0]]
        else:
            l = np.r_[0, n[2], -n[1]]

        l /= np.sqrt((l**2).sum(axis=-1))
        m = np.cross(n, l)
        return n, l, m

    @on_trait_change('Bnorm')
    def redraw(self):
        if hasattr(self, 'app') and self.app.scene._renderer is not None:
            self.display()
            self.app.visualize_field()

    def display(self):
        """
        Display the coil in the 3D view.
        """
        n, l, m = self.base_vectors()
        theta = np.linspace(0, 2 * np.pi, 30)[..., np.newaxis]
        coil = self.radius * (np.sin(theta) * l + np.cos(theta) * m)
        coil += self.position
        coil_x, coil_y, coil_z = coil.T
        if self._plot is None:
            self._plot = self.app.scene.mlab.plot3d(coil_x, coil_y, coil_z,
                                                    tube_radius=0.007,
                                                    color=(0, 0, 1),
                                                    name='Coil')
        else:
            self._plot.mlab_source.trait_set(x=coil_x, y=coil_y, z=coil_z)

    def _get_Bnorm(self):
        """
        returns the magnetic field for the current loop calculated
        from eqns (1) and (2) in Phys Rev A Vol. 35, N 4, pp. 1535-1546; 1987.

        The field is evaluated on the grid X, Y, Z (names resolved outside
        this class -- presumably module-level arrays of equal shape; verify).
        The components are stored on self.Bx, self.By and self.Bz, and the
        magnitude is returned.  Points whose magnitude exceeds ten times the
        median are set to NaN in all four arrays.
        """
        ### Translate the coordinates in the coil's frame
        n, l, m = self.base_vectors()
        R = self.radius
        r0 = self.position
        r = np.c_[np.ravel(X), np.ravel(Y), np.ravel(Z)]

        # transformation matrix coil frame to lab frame
        trans = np.vstack((l, m, n))

        r -= r0  # point location from center of coil
        r = np.dot(r, linalg.inv(trans))  # transform vector to coil frame

        #### calculate field

        # express the coordinates in polar form
        x = r[:, 0]
        y = r[:, 1]
        z = r[:, 2]
        rho = np.sqrt(x**2 + y**2)
        # azimuth measured from the x axis: arctan2 takes (y, x), so that
        # cos(theta) = x / rho and sin(theta) = y / rho below
        theta = np.arctan2(y, x)

        E = special.ellipe((4 * R * rho) / ((R + rho)**2 + z**2))
        K = special.ellipk((4 * R * rho) / ((R + rho)**2 + z**2))
        Bz = 1 / np.sqrt((R + rho)**2 + z**2) * (
            K + E * (R**2 - rho**2 - z**2) / ((R - rho)**2 + z**2))
        Brho = z / (rho * np.sqrt((R + rho)**2 + z**2)) * (
            -K + E * (R**2 + rho**2 + z**2) / ((R - rho)**2 + z**2))
        # On the axis of the coil we get a divided by zero here. This returns a
        # NaN, where the field is actually zero :
        Brho[np.isnan(Brho)] = 0

        B = np.c_[np.cos(theta) * Brho, np.sin(theta) * Brho, Bz]

        # Rotate the field back in the lab's frame
        B = np.dot(B, trans)

        Bx, By, Bz = B.T
        Bx = np.reshape(Bx, X.shape)
        By = np.reshape(By, X.shape)
        Bz = np.reshape(Bz, X.shape)

        Bnorm = np.sqrt(Bx**2 + By**2 + Bz**2)

        # We need to threshold ourselves, rather than with VTK, to be able
        # to use an ImageData
        Bmax = 10 * np.median(Bnorm)
        too_large = Bnorm > Bmax

        # np.nan (np.NAN was removed in NumPy 2.0)
        Bx[too_large] = np.nan
        By[too_large] = np.nan
        Bz[too_large] = np.nan
        Bnorm[too_large] = np.nan

        self.Bx = Bx
        self.By = By
        self.Bz = Bz
        return Bnorm
class ThermoSource(BaseSource):
    """ Source parameters exposed as traits.

    Each property setter sends a command through ``self.ask`` and only
    updates the cached shadow trait when the device answers 'ok'; the
    ``read_*`` methods query a value and cache it on the named trait.
    """

    trap_voltage = Property(depends_on='_trap_voltage')
    _trap_voltage = Float

    trap_current = Property(depends_on='_trap_current')
    _trap_current = Float

    z_symmetry = Property(depends_on='_z_symmetry')
    y_symmetry = Property(depends_on='_y_symmetry')
    extraction_lens = Property(Range(0, 100.0), depends_on='_extraction_lens')
    emission = Float

    _y_symmetry = Float  # Range(0.0, 100.)
    _z_symmetry = Float  # Range(0.0, 100.)

    y_symmetry_low = Float(-100.0)
    y_symmetry_high = Float(100.0)

    z_symmetry_low = Float(-100.0)
    z_symmetry_high = Float(100.0)

    _extraction_lens = Float  # Range(0.0, 100.)

    def set_hv(self, v):
        return self._set_value('SetHV', v)

    def read_emission(self):
        return self._read_value('GetParameter Source Current Readback',
                                'emission')

    def read_trap_current(self):
        return self._read_value('GetParameter Trap Current Readback',
                                '_trap_current')

    def read_y_symmetry(self):
        return self._read_value('GetYSymmetry', '_y_symmetry')

    def read_z_symmetry(self):
        return self._read_value('GetZSymmetry', '_z_symmetry')

    def read_trap_voltage(self):
        return self._read_value('GetParameter Trap Voltage Readback',
                                '_trap_voltage')

    def read_hv(self):
        return self._read_value('GetHighVoltage', 'current_hv')

    def _set_value(self, name, v):
        """ Send ``'<name> <v>'``; return True only if the reply is 'ok'
        (case- and whitespace-insensitive), otherwise False. """
        r = self.ask('{} {}'.format(name, v))
        return r is not None and r.lower().strip() == 'ok'

    def _read_value(self, name, value):
        """ Query ``name``, round the reply to 3 decimals, store it on the
        trait called ``value`` and return it.  Returns None (leaving the
        trait untouched) if the reply is not a number. """
        r = self.ask(name)
        try:
            r = float('{:0.3f}'.format(float(r)))
            setattr(self, value, r)
            return getattr(self, value)
        except (ValueError, TypeError):
            pass

    def sync_parameters(self):
        # NOTE(review): trap voltage and emission are not refreshed here --
        # confirm whether that is intentional.
        self.read_y_symmetry()
        self.read_z_symmetry()
        self.read_trap_current()
        self.read_hv()

    def traits_view(self):
        v = View(
            Item('nominal_hv', format_str='%0.4f'),
            Item('current_hv', format_str='%0.4f', style='readonly'),
            Item('trap_current'),
            Item('trap_voltage'),
            Item('y_symmetry', editor=RangeEditor(low_name='y_symmetry_low',
                                                  high_name='y_symmetry_high',
                                                  mode='slider')),
            Item('z_symmetry', editor=RangeEditor(low_name='z_symmetry_low',
                                                  high_name='z_symmetry_high',
                                                  mode='slider')),
            Item('extraction_lens'))
        return v

    # ===============================================================================
    # property get/set
    # ===============================================================================
    def _get_trap_voltage(self):
        return self._trap_voltage

    def _get_trap_current(self):
        return self._trap_current

    def _get_y_symmetry(self):
        return self._y_symmetry

    def _get_z_symmetry(self):
        return self._z_symmetry

    def _get_extraction_lens(self):
        return self._extraction_lens

    def _set_trap_voltage(self, v):
        if self._set_value('SetParameter', 'Trap Voltage Set,{}'.format(v)):
            # cache into the voltage shadow trait (was wrongly _trap_current)
            self._trap_voltage = v

    def _set_trap_current(self, v):
        if self._set_value('SetParameter', 'Trap Current Set,{}'.format(v)):
            self._trap_current = v

    def _set_y_symmetry(self, v):
        if self._set_value('SetYSymmetry', v):
            self._y_symmetry = v

    def _set_z_symmetry(self, v):
        if self._set_value('SetZSymmetry', v):
            self._z_symmetry = v

    def _set_extraction_lens(self, v):
        if self._set_value('SetExtractionLens', v):
            self._extraction_lens = v
class SimpleEditor(Editor):
    """ Simple style of editor for sets.

    The editor displays two list boxes, with buttons for moving the selected
    items from left to right, or vice versa. If **can_move_all** on the
    factory is True, then buttons are displayed for moving all the items to
    one box or the other. If the set is ordered, buttons are displayed for
    moving the selected item up or down in right-side list box.
    """

    # -------------------------------------------------------------------------
    #  Trait definitions:
    # -------------------------------------------------------------------------

    #: Current set of enumeration names:
    names = Property()

    #: Current mapping from names to values:
    mapping = Property()

    #: Current inverse mapping from values to names:
    inverse_mapping = Property()

    #: Is set editor scrollable? This value overrides the default.
    scrollable = True

    def init(self, parent):
        """ Finishes initializing the editor by creating the underlying toolkit
            widget.

        Layout: left ("unused") list box, a column of transfer buttons, and
        the right ("used") list box, all inside one horizontal sizer.
        """
        factory = self.factory
        if factory.name != "":
            # Enumeration values come from another object's trait.
            self._object, self._name, self._value = self.parse_extended_name(
                factory.name
            )
            self.values_changed()
            self._object.on_trait_change(
                self._values_changed, self._name, dispatch="ui"
            )
        else:
            # Enumeration values come from the factory's own ``values``.
            self._value = lambda: self.factory.values
            self.values_changed()
            factory.on_trait_change(
                self._values_changed, "values", dispatch="ui"
            )

        self.control = panel = TraitsUIPanel(parent, -1)
        hsizer = wx.BoxSizer(wx.HORIZONTAL)
        vsizer = wx.BoxSizer(wx.VERTICAL)

        self._unused = self._create_listbox(
            panel,
            hsizer,
            self._on_unused,
            self._on_use,
            factory.left_column_title,
        )

        # Optional buttons stay None when the factory does not request them.
        self._use_all = self._unuse_all = self._up = self._down = None
        if factory.can_move_all:
            self._use_all = self._create_button(
                ">>", panel, vsizer, 15, self._on_use_all
            )

        self._use = self._create_button(">", panel, vsizer, 15, self._on_use)
        self._unuse = self._create_button(
            "<", panel, vsizer, 0, self._on_unuse
        )
        if factory.can_move_all:
            self._unuse_all = self._create_button(
                "<<", panel, vsizer, 15, self._on_unuse_all
            )

        if factory.ordered:
            self._up = self._create_button(
                "Move Up", panel, vsizer, 30, self._on_up
            )
            self._down = self._create_button(
                "Move Down", panel, vsizer, 0, self._on_down
            )

        hsizer.Add(vsizer, 0, wx.LEFT | wx.RIGHT, 8)

        self._used = self._create_listbox(
            panel,
            hsizer,
            self._on_value,
            self._on_unuse,
            factory.right_column_title,
        )

        panel.SetSizer(hsizer)

        # Rebuild the boxes when items of the edited list change in place.
        self.context_object.on_trait_change(
            self.update_editor, self.extended_name + "_items?", dispatch="ui"
        )
        self.set_tooltip()

    def _get_names(self):
        """ Gets the current set of enumeration names.
        """
        return self._names

    def _get_mapping(self):
        """ Gets the current mapping.
        """
        return self._mapping

    def _get_inverse_mapping(self):
        """ Gets the current inverse mapping.
        """
        return self._inverse_mapping

    def _create_listbox(self, parent, sizer, handler1, handler2, title):
        """ Creates a list box.

        A bold *title* label is placed above the box.  *handler1* is bound to
        selection changes and *handler2* to double-clicks.  Returns the new
        ``wx.ListBox``.
        """
        column_sizer = wx.BoxSizer(wx.VERTICAL)

        # Add the column title in emphasized text:
        title_widget = wx.StaticText(parent, -1, title)
        font = title_widget.GetFont()
        emphasis_font = wx.Font(
            font.GetPointSize() + 1, font.GetFamily(), font.GetStyle(), wx.BOLD
        )
        title_widget.SetFont(emphasis_font)
        column_sizer.Add(title_widget, 0, 0)

        # Create the list box and add it to the column (named so it does not
        # shadow the ``list`` builtin):
        listbox = wx.ListBox(
            parent, -1, style=wx.LB_EXTENDED | wx.LB_NEEDED_SB
        )
        column_sizer.Add(listbox, 1, wx.EXPAND)

        # Add the column to the SetEditor widget:
        sizer.Add(column_sizer, 1, wx.EXPAND)

        # Hook up the event handlers:
        parent.Bind(wx.EVT_LISTBOX, handler1, id=listbox.GetId())
        parent.Bind(wx.EVT_LISTBOX_DCLICK, handler2, id=listbox.GetId())

        return listbox

    def _create_button(self, label, parent, sizer, space_before, handler):
        """ Creates a button.

        Adds *space_before* pixels of spacing, then the button, to *sizer*
        and binds *handler* to its click event.  Returns the button.
        """
        button = wx.Button(parent, -1, label, style=wx.BU_EXACTFIT)
        sizer.AddSpacer(space_before)
        sizer.Add(button, 0, wx.EXPAND | wx.BOTTOM, 8)
        parent.Bind(wx.EVT_BUTTON, handler, id=button.GetId())
        return button

    def values_changed(self):
        """ Recomputes the cached data based on the underlying enumeration model
            or the values of the factory.
        """
        self._names, self._mapping, self._inverse_mapping = enum_values_changed(
            self._value(), self.string_value
        )

    def _values_changed(self):
        """ Handles the underlying object model's enumeration set or factory's
            values being changed.
        """
        self.values_changed()
        self.update_editor()

    def update_editor(self):
        """ Updates the editor when the object trait changes externally to the
            editor.

        Rebuilds both list boxes from the current value, preserving any
        selections whose labels still exist.
        """
        # Check for any items having been deleted from the enumeration that are
        # still present in the object value:
        mapping = self.inverse_mapping.copy()
        values = [v for v in self.value if v in mapping]
        if len(values) < len(self.value):
            # Assigning self.value triggers another update; stop here.
            self.value = values
            return

        # Get a list of the selected items in the right box:
        used = self._used
        used_labels = self._get_selected_strings(used)

        # Get a list of the selected items in the left box:
        unused = self._unused
        unused_labels = self._get_selected_strings(unused)

        # Empty list boxes in preparation for rebuilding from current values:
        used.Clear()
        unused.Clear()

        # Ensure right list box is kept alphabetized unless insertion
        # order is relevant:
        if not self.factory.ordered:
            values = sorted(values)

        # Rebuild the right listbox; whatever remains in ``mapping``
        # afterwards is the set of unused values:
        used_selections = []
        for i, value in enumerate(values):
            label = mapping[value]
            used.Append(label)
            del mapping[value]
            if label in used_labels:
                used_selections.append(i)

        # Rebuild the left listbox:
        unused_selections = []
        unused_items = sorted(mapping.values())
        mapping = self.mapping
        self._unused_items = [mapping[ui] for ui in unused_items]
        for i, unused_item in enumerate(unused_items):
            unused.Append(unused_item)
            if unused_item in unused_labels:
                unused_selections.append(i)

        # If nothing is selected, default selection should be top of left box,
        # or of right box if left box is empty:
        if (len(used_selections) == 0) and (len(unused_selections) == 0):
            if unused.GetCount() == 0:
                used_selections.append(0)
            else:
                unused_selections.append(0)

        used_count = used.GetCount()
        for i in used_selections:
            if i < used_count:
                used.SetSelection(i)

        unused_count = unused.GetCount()
        for i in unused_selections:
            if i < unused_count:
                unused.SetSelection(i)

        self._check_up_down()
        self._check_left_right()

    def dispose(self):
        """ Disposes of the contents of an editor.

        Removes the listeners installed by :meth:`init`.
        """
        if self._object is not None:
            self._object.on_trait_change(
                self._values_changed, self._name, remove=True
            )
        else:
            self.factory.on_trait_change(
                self._values_changed, "values", remove=True
            )

        self.context_object.on_trait_change(
            self.update_editor, self.extended_name + "_items?", remove=True
        )

        super(SimpleEditor, self).dispose()

    def get_error_control(self):
        """ Returns the editor's control for indicating error status.
        """
        return [self._unused, self._used]

    # -------------------------------------------------------------------------
    #  Event handlers:
    # -------------------------------------------------------------------------

    def _on_value(self, event):
        """ Selection changed in the right ("used") box. """
        if not self.factory.ordered:
            self._clear_selection(self._unused)
        self._check_left_right()
        self._check_up_down()

    def _on_unused(self, event):
        """ Selection changed in the left ("unused") box. """
        if not self.factory.ordered:
            self._clear_selection(self._used)
        self._check_left_right()
        self._check_up_down()

    def _on_use(self, event):
        """ Move the selected unused items into the value. """
        self._unused_items, self.value = self._transfer_items(
            self._unused, self._used, self._unused_items, self.value
        )

    def _on_unuse(self, event):
        """ Move the selected used items out of the value. """
        self.value, self._unused_items = self._transfer_items(
            self._used, self._unused, self.value, self._unused_items
        )

    def _on_use_all(self, event):
        """ Move every unused item into the value. """
        self._unused_items, self.value = self._transfer_all(
            self._unused, self._used, self._unused_items, self.value
        )

    def _on_unuse_all(self, event):
        """ Move every used item out of the value. """
        self.value, self._unused_items = self._transfer_all(
            self._used, self._unused, self.value, self._unused_items
        )

    def _on_up(self, event):
        self._move_item(-1)

    def _on_down(self, event):
        self._move_item(1)

    # -------------------------------------------------------------------------
    #  Private methods:
    # -------------------------------------------------------------------------

    def _clear_selection(self, box):
        """ Unselects all items in the given ListBox
        """
        for i in box.GetSelections():
            box.Deselect(i)

    def _transfer_all(self, list_from, list_to, values_from, values_to):
        """ Transfers all items from one list to another.

        Returns new ``(values_from, values_to)`` lists; the inputs are not
        mutated.
        """
        values_from = values_from[:]
        values_to = values_to[:]

        self._clear_selection(list_from)
        while list_from.GetCount() > 0:
            index_to = list_to.GetCount()
            list_from.SetSelection(0)
            list_to.InsertItems(
                self._get_selected_strings(list_from), index_to
            )
            list_from.Delete(0)
            values_to.append(values_from[0])
            del values_from[0]

        list_to.SetSelection(0)
        self._check_left_right()
        self._check_up_down()

        return (values_from, values_to)

    def _transfer_items(self, list_from, list_to, values_from, values_to):
        """ Transfers the selected item from one list to another.

        Selected labels of *list_from* are inserted into *list_to* at its
        first selected index (or 0).  Returns new ``(values_from, values_to)``
        lists; the inputs are not mutated.
        """
        values_from = values_from[:]
        values_to = values_to[:]
        indices_from = list_from.GetSelections()
        index_from = max(self._get_first_selection(list_from), 0)
        index_to = max(self._get_first_selection(list_to), 0)

        self._clear_selection(list_to)

        # Get the list of strings in the "from" box to be moved:
        selected_list = self._get_selected_strings(list_from)

        # fixme: I don't know why I have to reverse the list to get
        # correct behavior from the ordered list box.  Investigate -- LP
        # NOTE(review): the list box receives the labels in this reversed
        # order, while the loop below inserts each value at the same
        # ``index_to`` and so leaves values_to in the original order.
        selected_list.reverse()
        list_to.InsertItems(selected_list, index_to)

        # Delete the transferred items from the left box (highest index
        # first so earlier indices stay valid):
        for i in range(len(indices_from) - 1, -1, -1):
            list_from.Delete(indices_from[i])

        # Delete the transferred items from the "unused" value list:
        for item_label in selected_list:
            val_index_from = values_from.index(self.mapping[item_label])
            values_to.insert(index_to, values_from[val_index_from])
            del values_from[val_index_from]

            # If right list is ordered, keep moved items selected:
            if self.factory.ordered:
                list_to.SetSelection(list_to.FindString(item_label))

        # Reset the selection in the left box:
        count = list_from.GetCount()
        if count > 0:
            if index_from >= count:
                index_from -= 1
            list_from.SetSelection(index_from)

        self._check_left_right()
        self._check_up_down()

        return (values_from, values_to)

    def _move_item(self, direction):
        """ Moves an item up or down within the "used" list.

        *direction* is -1 (up) or +1 (down); the list box entry and the two
        adjacent entries of the trait value are swapped.
        """
        # Move the item up/down within the list:
        listbox = self._used
        index_from = self._get_first_selection(listbox)
        index_to = index_from + direction
        label = listbox.GetString(index_from)
        listbox.Deselect(index_from)
        listbox.Delete(index_from)
        listbox.Insert(label, index_to)
        listbox.SetSelection(index_to)

        # Enable the up/down buttons appropriately:
        self._check_up_down()

        # Move the item up/down within the editor's trait value:
        value = self.value
        if direction < 0:
            index = index_to
            values = [value[index_from], value[index_to]]
        else:
            index = index_from
            values = [value[index_to], value[index_from]]
        self.value = value[:index] + values + value[index + 2 :]

    def _check_up_down(self):
        """ Sets the proper enabled state for the up and down buttons.

        Only meaningful (and only called into the buttons) for ordered sets.
        """
        if self.factory.ordered:
            index_selected = self._used.GetSelections()
            self._up.Enable(
                (len(index_selected) == 1) and (index_selected[0] > 0)
            )
            self._down.Enable(
                (len(index_selected) == 1)
                and (index_selected[0] < (self._used.GetCount() - 1))
            )

    def _check_left_right(self):
        """ Sets the proper enabled state for the left and right buttons.
        """
        self._use.Enable(
            self._unused.GetCount() > 0
            and self._get_first_selection(self._unused) >= 0
        )
        self._unuse.Enable(
            self._used.GetCount() > 0
            and self._get_first_selection(self._used) >= 0
        )

        if self.factory.can_move_all:
            self._use_all.Enable(
                (self._unused.GetCount() > 0)
                and (self._get_first_selection(self._unused) >= 0)
            )
            self._unuse_all.Enable(
                (self._used.GetCount() > 0)
                and (self._get_first_selection(self._used) >= 0)
            )

    def _get_selected_strings(self, listbox):
        """ Returns a list of the selected strings in the given *listbox*.
        """
        return [listbox.GetString(i) for i in listbox.GetSelections()]

    def _get_first_selection(self, listbox):
        """ Returns the index of the first (or only) selected item, or -1 when
            nothing is selected.
        """
        select_list = listbox.GetSelections()
        if len(select_list) == 0:
            return -1

        return select_list[0]
class BaseEditor(Editor):
    """ Base class for enumeration editors.

    Keeps cached ``names`` / ``mapping`` / ``inverse_mapping`` data in sync
    with the enumeration source (another object's trait or the factory's
    ``values``) and asks subclasses to rebuild when that source changes.
    """

    #: Current set of enumeration names:
    names = Property()

    #: Current mapping from names to values:
    mapping = Property()

    #: Current inverse mapping from values to names:
    inverse_mapping = Property()

    # -------------------------------------------------------------------------
    # BaseEditor Interface
    # -------------------------------------------------------------------------

    def values_changed(self):
        """ Recompute the cached enumeration data from the current source
            (the referenced object trait or the factory's values).
        """
        names, name_to_value, value_to_name = enum_values_changed(
            self._value(), self.string_value)
        self._names = names
        self._mapping = name_to_value
        self._inverse_mapping = value_to_name

    def rebuild_editor(self):
        """ Rebuild the editor contents after the factory's **values** trait
            changes.

        Subclasses must override this; it is not needed by the Qt backends.
        """
        raise NotImplementedError

    # -------------------------------------------------------------------------
    # Editor Interface
    # -------------------------------------------------------------------------

    def init(self, parent):
        """ Finish initializing the editor by wiring up the enumeration source.
        """
        factory = self.factory

        # No extended name: the factory's own ``values`` trait is the source.
        if factory.name == "":
            self._value = lambda: self.factory.values
            self.values_changed()
            factory.on_trait_change(self._values_changed, "values",
                                    dispatch="ui")
            return

        # Otherwise the source is a trait on another object.
        self._object, self._name, self._value = self.parse_extended_name(
            factory.name)
        self.values_changed()
        self._object.on_trait_change(self._values_changed, " " + self._name,
                                     dispatch="ui")

    def dispose(self):
        """ Remove the listener installed by :meth:`init`, then dispose.
        """
        if self._object is None:
            self.factory.on_trait_change(self._values_changed, "values",
                                         remove=True)
        else:
            self._object.on_trait_change(self._values_changed,
                                         " " + self._name, remove=True)

        super(BaseEditor, self).dispose()

    # -------------------------------------------------------------------------
    # Private interface
    # -------------------------------------------------------------------------

    # Property getters --------------------------------------------------------

    def _get_names(self):
        """ Return the cached enumeration names. """
        return self._names

    def _get_mapping(self):
        """ Return the cached name -> value mapping. """
        return self._mapping

    def _get_inverse_mapping(self):
        """ Return the cached value -> name mapping. """
        return self._inverse_mapping

    # Trait change handlers ---------------------------------------------------

    def _values_changed(self):
        """ React to the enumeration source changing: refresh caches, rebuild.
        """
        self.values_changed()
        self.rebuild_editor()
class FERefinementLevel(FESubDomain):
    """A refinement level in a hierarchy of finite-element subdomains.

    A level without a parent is attached directly to a domain; a level with
    a parent inherits its domain from that parent and registers itself in the
    parent's ``children`` list.
    """

    # specialized label
    _tree_label = Str('refinement level')

    def _set_domain(self, value):
        'reset the domain of this domain'
        # Only parentless levels may be attached to a domain directly.
        if self.parent is not None:
            raise TraitError('child FESubDomain cannot be added to FEDomain')
        super(FERefinementLevel, self)._set_domain(value)

    def _get_domain(self):
        # Child levels report their parent's domain.
        if self.parent is not None:
            return self.parent.domain
        return super(FERefinementLevel, self)._get_domain()

    def validate(self):
        """Raise ValueError if this level has a parent."""
        if self.parent is not None:
            raise ValueError(
                'only parentless subdomains can be inserted into domain')

    # children domains: list of the instances of the same class
    children = List(This)

    # parent domain
    _parent = This(domain_changed=True)
    parent = Property(This)

    def _set_parent(self, value):
        'reset the parent of this domain'
        if self._parent:
            # check to see that the changed parent
            # is within the same domain
            if value.domain != self._parent.domain:
                raise NotImplementedError(
                    'Parent change across domains not implemented')
            # unregister in the old parent
            self._parent.children.remove(self)
        else:
            # append the current subdomain at the end of the subdomain
            # series within the domain
            value.domain._append_in_series(self)
        # set the new parent
        self._parent = value
        # register the subdomain in the new parent
        self._parent.children.append(self)

    def _get_parent(self):
        return self._parent

    #---------------------------------------------------------------------------------------
    # Implement the child interface
    #---------------------------------------------------------------------------------------

    # Element refinement representation: parent element index -> the
    # refinement arguments passed to refine_elem.
    refinement_dict = Dict(changed_structure=True)

    def refine_elem(self, parent_ix, *refinement_args):
        '''For the specified parent position let the new element decompose.

        Records *refinement_args* under *parent_ix* and deactivates that
        element in the parent level.  Raises ValueError if *parent_ix* has
        already been refined.
        '''
        if parent_ix in self.refinement_dict:
            raise ValueError('element %s already refined' % repr(parent_ix))

        # the element is deactivated in the parent domain
        self.refinement_dict[parent_ix] = refinement_args
        self.parent.deactivate(parent_ix)
class DataRange1D(BaseDataRange):
    """ Represents a 1-D data range.

    The user-visible ``low_setting`` / ``high_setting`` may each be 'auto',
    'track' or an explicit number; the resolved numeric bounds are kept in
    ``_low_value`` / ``_high_value`` and exposed as ``low`` / ``high``.
    """

    # The actual value of the lower bound of this range (overrides
    # AbstractDataRange). To set it, use **low_setting**.
    low = Property
    # The actual value of the upper bound of this range (overrides
    # AbstractDataRange). To set it, use **high_setting**.
    high = Property

    # Property for the lower bound of this range (overrides AbstractDataRange).
    #
    # * 'auto': The lower bound is automatically set at or below the minimum
    #   of the data.
    # * 'track': The lower bound tracks the upper bound by **tracking_amount**.
    # * CFloat: An explicit value for the lower bound
    low_setting = Property(Trait('auto', 'auto', 'track', CFloat))
    # Property for the upper bound of this range (overrides AbstractDataRange).
    #
    # * 'auto': The upper bound is automatically set at or above the maximum
    #   of the data.
    # * 'track': The upper bound tracks the lower bound by **tracking_amount**.
    # * CFloat: An explicit value for the upper bound
    high_setting = Property(Trait('auto', 'auto', 'track', CFloat))

    # Do "auto" bounds imply an exact fit to the data? If False,
    # they pad a little bit of margin on either side.
    tight_bounds = Bool(True)

    # A user supplied function returning the proper bounding interval.
    # bounds_func takes (data_low, data_high, margin, tight_bounds)
    # and returns (low, high)
    bounds_func = Callable

    # The amount of margin to place on either side of the data, expressed as
    # a percentage of the full data width
    margin = Float(0.05)

    # The minimum percentage difference between low and high.  That is,
    # (high-low) >= epsilon * low.
    # Used to be 1.0e-20 but chaco cannot plot at such a precision!
    epsilon = CFloat(1.0e-10)

    # When either **high** or **low** tracks the other, track by this amount.
    default_tracking_amount = CFloat(20.0)

    # The current tracking amount. This value changes with zooming.
    tracking_amount = default_tracking_amount

    # Default tracking state. This value is used when self.reset() is called.
    #
    # * 'auto': Both bounds reset to 'auto'.
    # * 'high_track': The high bound resets to 'track', and the low bound
    #   resets to 'auto'.
    # * 'low_track': The low bound resets to 'track', and the high bound
    #   resets to 'auto'.
    default_state = Enum('auto', 'high_track', 'low_track')

    # FIXME: this attribute is not used anywhere, is it safe to remove it?
    # Is this range dependent upon another range?
    fit_to_subset = Bool(False)

    #------------------------------------------------------------------------
    # Private traits
    #------------------------------------------------------------------------

    # The "_setting" attributes correspond to what the user has "set"; the
    # "_value" attributes are the actual numerical values for the given
    # setting.

    # The user-specified low setting.
    _low_setting = Trait('auto', 'auto', 'track', CFloat)
    # The actual numerical value for the low setting.
    _low_value = CFloat(-inf)
    # The user-specified high setting.
    _high_setting = Trait('auto', 'auto', 'track', CFloat)
    # The actual numerical value for the high setting.
    _high_value = CFloat(inf)

    # A list of attributes to persist
    # _pickle_attribs = ("_low_setting", "_high_setting")

    #------------------------------------------------------------------------
    # AbstractRange interface
    #------------------------------------------------------------------------

    def clip_data(self, data):
        """ Returns a list of data values that are within the range.

        Implements AbstractDataRange.
        """
        return compress(self.mask_data(data), data)

    def mask_data(self, data):
        """ Returns a mask array, indicating whether values in the given array
        are inside the range.

        Both bounds are inclusive. Implements AbstractDataRange.
        """
        return ((data.view(ndarray) >= self._low_value) &
                (data.view(ndarray) <= self._high_value))

    def bound_data(self, data):
        """ Returns a tuple of indices for the start and end of the first run
        of *data* that falls within the range.

        The end index is inclusive; (0, 0) is returned when no value of
        *data* is inside the range. Implements AbstractDataRange.
        """
        mask = self.mask_data(data)
        runs = arg_find_runs(mask, "flat")
        # Since runs of "0" are also considered runs, we have to cycle through
        # until we find the first run of "1"s.
        for run in runs:
            if mask[run[0]] == 1:
                # arg_find_runs returns 1 past the end
                return run[0], run[1] - 1
        return (0, 0)

    def set_bounds(self, low, high):
        """ Sets all the bounds of the range simultaneously.

        Fires at most one ``updated`` event. Implements AbstractDataRange.
        """
        if low == 'track':
            # Set the high setting first
            result_high = self._do_set_high_setting(high, fire_event=False)
            result_low = self._do_set_low_setting(low, fire_event=False)
            result = result_low or result_high
        else:
            # Either set low first or order doesn't matter
            result_low = self._do_set_low_setting(low, fire_event=False)
            result_high = self._do_set_high_setting(high, fire_event=False)
            result = result_high or result_low
        if result:
            self.updated = result

    def scale_tracking_amount(self, multiplier):
        """ Sets the **tracking_amount** to a new value, scaled by *multiplier*.
        """
        self.tracking_amount = self.tracking_amount * multiplier
        self._do_track()

    def set_tracking_amount(self, amount):
        """ Sets the **tracking_amount** to a new value, *amount*.
        """
        self.tracking_amount = amount
        self._do_track()

    def set_default_tracking_amount(self, amount):
        """ Sets the **default_tracking_amount** to a new value, *amount*.
        """
        self.default_tracking_amount = amount

    #------------------------------------------------------------------------
    # Public methods
    #------------------------------------------------------------------------

    def reset(self):
        """ Resets the bounds of this range, based on **default_state**.
        """
        # need to maintain 'track' setting
        if self.default_state == 'auto':
            self._high_setting = 'auto'
            self._low_setting = 'auto'
        elif self.default_state == 'low_track':
            self._high_setting = 'auto'
            self._low_setting = 'track'
        elif self.default_state == 'high_track':
            self._high_setting = 'track'
            self._low_setting = 'auto'
        self._refresh_bounds()
        self.tracking_amount = self.default_tracking_amount

    def refresh(self):
        """ If any of the bounds is 'auto', this method refreshes the actual
        low and high values from the set of the view filters' data sources.
        """
        if ('auto' in (self._low_setting, self._high_setting)) or \
            ('track' in (self._low_setting, self._high_setting)):
            # If the user has hard-coded bounds, then refresh() doesn't do
            # anything.
            self._refresh_bounds()
        else:
            return

    #------------------------------------------------------------------------
    # Private methods (getters and setters)
    #------------------------------------------------------------------------

    def _get_low(self):
        return float(self._low_value)

    def _set_low(self, val):
        return self._set_low_setting(val)

    def _get_low_setting(self):
        return self._low_setting

    def _do_set_low_setting(self, val, fire_event=True):
        """
        Returns
        -------
        If fire_event is False and the change would have fired an event, returns
        the tuple of the new low and high values.  Otherwise returns None.  In
        particular, if fire_event is True, it always returns None.
        """
        new_values = None
        if self._low_setting != val:

            # Save the new setting.
            self._low_setting = val

            # If val is 'auto' or 'track', get the corresponding numerical
            # value.
            if val == 'auto':
                if len(self.sources) > 0:
                    val = min(
                        [source.get_bounds()[0] for source in self.sources])
                else:
                    val = -inf
            elif val == 'track':
                if len(self.sources) > 0 or self._high_setting != 'auto':
                    val = self._high_value - self.tracking_amount
                else:
                    val = -inf

            # val is now a numerical value.  If it is the same as the current
            # value, there is nothing to do.
            if self._low_value != val:
                self._low_value = val
                if self._high_setting == 'track':
                    self._high_value = val + self.tracking_amount
                if fire_event:
                    self.updated = (self._low_value, self._high_value)
                else:
                    new_values = (self._low_value, self._high_value)

        return new_values

    def _set_low_setting(self, val):
        self._do_set_low_setting(val, True)

    def _get_high(self):
        return float(self._high_value)

    def _set_high(self, val):
        return self._set_high_setting(val)

    def _get_high_setting(self):
        return self._high_setting

    def _do_set_high_setting(self, val, fire_event=True):
        """
        Returns
        -------
        If fire_event is False and the change would have fired an event, returns
        the tuple of the new low and high values.  Otherwise returns None.  In
        particular, if fire_event is True, it always returns None.
        """
        new_values = None
        if self._high_setting != val:

            # Save the new setting.
            self._high_setting = val

            # If val is 'auto' or 'track', get the corresponding numerical
            # value.
            if val == 'auto':
                if len(self.sources) > 0:
                    val = max(
                        [source.get_bounds()[1] for source in self.sources])
                else:
                    val = inf
            elif val == 'track':
                if len(self.sources) > 0 or self._low_setting != 'auto':
                    val = self._low_value + self.tracking_amount
                else:
                    val = inf

            # val is now a numerical value.  If it is the same as the current
            # value, there is nothing to do.
            if self._high_value != val:
                self._high_value = val
                if self._low_setting == 'track':
                    self._low_value = val - self.tracking_amount
                if fire_event:
                    self.updated = (self._low_value, self._high_value)
                else:
                    new_values = (self._low_value, self._high_value)

        return new_values

    def _set_high_setting(self, val):
        self._do_set_high_setting(val, True)

    def _refresh_bounds(self):
        """ Recompute ``_low_value`` / ``_high_value`` from the sources.

        Fires ``updated`` only on the non-empty path, and only when a bound
        actually changed.
        """
        null_bounds = False
        if len(self.sources) == 0:
            null_bounds = True
        else:
            # Ignore empty sources when collecting bounds.
            bounds_list = [source.get_bounds() for source in self.sources
                           if source.get_size() > 0]

            if len(bounds_list) == 0:
                null_bounds = True

        if null_bounds:
            # If we have no sources and our settings are "auto", then reset our
            # bounds to infinity; otherwise, set the _value to the corresponding
            # setting.
            if (self._low_setting in ("auto", "track")):
                self._low_value = -inf
            else:
                self._low_value = self._low_setting
            if (self._high_setting in ("auto", "track")):
                self._high_value = inf
            else:
                self._high_value = self._high_setting
            return
        else:
            mins, maxes = list(zip(*bounds_list))

            low_start, high_start = \
                     calc_bounds(self._low_setting, self._high_setting,
                                 mins, maxes, self.epsilon,
                                 self.tight_bounds, margin=self.margin,
                                 track_amount=self.tracking_amount,
                                 bounds_func=self.bounds_func)

        if (self._low_value != low_start) or (self._high_value != high_start):
            self._low_value = low_start
            self._high_value = high_start
            self.updated = (self._low_value, self._high_value)
        return

    def _do_track(self):
        """ Re-apply **tracking_amount** to whichever bound is tracking.
        """
        changed = False
        if self._low_setting == 'track':
            new_value = self._high_value - self.tracking_amount
            if self._low_value != new_value:
                self._low_value = new_value
                changed = True
        elif self._high_setting == 'track':
            new_value = self._low_value + self.tracking_amount
            if self._high_value != new_value:
                self._high_value = new_value
                changed = True
        if changed:
            self.updated = (self._low_value, self._high_value)

    #------------------------------------------------------------------------
    # Event handlers
    #------------------------------------------------------------------------

    def _sources_items_changed(self, event):
        self.refresh()
        for source in event.removed:
            source.on_trait_change(self.refresh, "data_changed", remove=True)
        for source in event.added:
            source.on_trait_change(self.refresh, "data_changed")

    def _sources_changed(self, old, new):
        """ Refresh and move ``data_changed`` listeners from *old* to *new*.

        *old* may be None (as passed by :meth:`_post_load`), in which case
        there are no previous listeners to remove.
        """
        self.refresh()
        if old is not None:
            for source in old:
                source.on_trait_change(self.refresh, "data_changed",
                                       remove=True)
        for source in new:
            source.on_trait_change(self.refresh, "data_changed")

    #------------------------------------------------------------------------
    # Serialization interface
    #------------------------------------------------------------------------

    def _post_load(self):
        # Re-install source listeners after unpickling; there are no old
        # sources whose listeners need removing.
        self._sources_changed(None, self.sources)
class ToolkitEditorFactory(EditorFactory):
    """ Editor factory for buttons.
    """

    # -------------------------------------------------------------------------
    #  Trait definitions:
    # -------------------------------------------------------------------------

    # Value to set when the button is clicked
    value = Property()

    # Optional label for the button
    label = Str()

    # The name of the external object trait that the button label is synced to
    label_value = Str()

    # The name of the trait on the object that contains the list of possible
    # values.  If this is set, then the value, label, and label_value traits
    # are ignored; instead, they will be set from this list.  When this button
    # is clicked, the value set will be the one selected from the drop-down.
    values_trait = Either(None, Str)

    # (Optional) Image to display on the button
    image = Image

    # Extra padding to add to both the left and the right sides
    width_padding = Range(0, 31, 7)

    # Extra padding to add to both the top and the bottom sides
    height_padding = Range(0, 31, 5)

    # Presentation style
    style = Enum("button", "radio", "toolbar", "checkbox")

    # Orientation of the text relative to the image
    orientation = Enum("vertical", "horizontal")

    # The optional view to display when the button is clicked:
    view = AView

    # -------------------------------------------------------------------------
    #  Traits view definition:
    # -------------------------------------------------------------------------

    traits_view = View(["label", "value", "|[]"])

    def _get_value(self):
        return self._value

    def _set_value(self, value):
        """Store *value*, converting numeric strings to numbers.

        A string that parses as ``int`` is stored as an int; otherwise one
        that parses as ``float`` is stored as a float; any other string, and
        any non-string value, is stored unchanged.
        """
        self._value = value
        if isinstance(value, str):
            # int()/float() report unparsable strings with ValueError only;
            # catching that (instead of a bare except) no longer swallows
            # KeyboardInterrupt/SystemExit or unrelated errors.
            try:
                self._value = int(value)
            except ValueError:
                try:
                    self._value = float(value)
                except ValueError:
                    pass

    def __init__(self, **traits):
        # Initialize the shadow value before traits (possibly including
        # ``value``) are assigned by the base-class constructor.
        self._value = 0
        super(ToolkitEditorFactory, self).__init__(**traits)
class ScatterPlot(BaseXYPlot):
    """
    Renders a scatter plot, given an index and value arrays.
    """

    # The CompiledPath to use if **marker** is set to "custom". This attribute
    # must be a compiled path for the Kiva context onto which this plot will
    # be rendered.  Usually, importing kiva.GraphicsContext will do
    # the right thing.
    custom_symbol = Any

    #------------------------------------------------------------------------
    # Styles on a ScatterPlot
    #------------------------------------------------------------------------

    # The type of marker to use.  This is a mapped trait using strings as the
    # keys.
    marker = MarkerTrait

    # The pixel size of the markers, not including the thickness of the outline.
    # Accepts either a single float or an array (presumably one size per
    # point -- TODO confirm).
    # NOTE(review): an earlier comment claimed the default is 4.0, but no
    # default is passed to Either(Float, Array) here; confirm where that
    # value is actually set.
    # TODO: for consistency, there should be a size data source and a mapper
    marker_size = Either(Float, Array)

    # The function which actually renders the markers
    render_markers_func = Callable(render_markers)

    # The thickness, in pixels, of the outline to draw around the marker.  If
    # this is 0, no outline is drawn.
    line_width = Float(1.0)

    # The fill color of the marker.
    color = black_color_trait

    # The color of the outline to draw around the marker.
    outline_color = black_color_trait

    # The RGBA tuple for rendering lines.  It is always a tuple of length 4.
    # It has the same RGB values as color_, and its alpha value is the alpha
    # value of self.color multiplied by self.alpha.
    # NOTE(review): "lines" vs. "fill" in this comment and the next look
    # swapped relative to color (fill) / outline_color (outline); confirm
    # against the property getters, which are not visible here.
    effective_color = Property(Tuple, depends_on=['color', 'alpha'])

    # The RGBA tuple for rendering the fill.  It is always a tuple of length 4.
    # It has the same RGB values as outline_color_, and its alpha value is the
    # alpha value of self.outline_color multiplied by self.alpha.
    effective_outline_color = Property(Tuple, depends_on=['outline_color', 'alpha'])

    # Traits UI View for customizing the plot.
    traits_view = ScatterPlotView()

    #------------------------------------------------------------------------
    # Selection and selection rendering
    # A selection on the plot is indicated by setting the index or value
    # datasource's 'selections' metadata item to a list of indices, or the
    # 'selection_mask' metadata to a boolean array of the same length as the
    # datasource.
    #------------------------------------------------------------------------

    # Whether selected points are drawn with the selection_* styles below.
    show_selection = Bool(True)

    selection_marker = MarkerTrait

    selection_marker_size = Float(4.0)

    selection_line_width = Float(1.0)

    selection_color = ColorTrait("yellow")

    selection_outline_color = black_color_trait

    #------------------------------------------------------------------------
    # Private traits
    #------------------------------------------------------------------------

    # Caches for selection rendering; _selection_cache_valid presumably flags
    # whether they are up to date (TODO confirm against the render code).
    _cached_selected_pts = Trait(None, None, Array)
    _cached_selected_screen_pts = Array
    _cached_point_mask = Array
    _cached_selection_point_mask = Array
    _selection_cache_valid = Bool(False)

    #------------------------------------------------------------------------
    # Overridden PlotRenderer methods
    #------------------------------------------------------------------------

    def map_screen(self, data_array):
        """ Maps an array of data points into screen space and returns it as
        an array.

        Implements the AbstractPlotRenderer interface.

        A 2-D input is treated as N (index, value) rows; a 1-D input is
        treated as a single (index, value) point.  Screen x/y are swapped
        unless ``orientation == "h"``.  An empty input returns an empty
        *list*, not an array.
        """
        # data_array is Nx2 array
        if len(data_array) == 0:
            return []

        # XXX: For some reason, doing the tuple unpacking doesn't work:
        #        x_ary, y_ary = transpose(data_array)
        # There is a mysterious error "object of too small depth for
        # desired array".  However, if you catch this exception and
        # try to execute the very same line of code again, it works
        # without any complaints.
        #
        # For now, we just use slicing to assign the X and Y arrays.
        data_array = asarray(data_array)
        if len(data_array.shape) == 1:
            # A single point given as [index, value].
            x_ary = data_array[0]
            y_ary = data_array[1]
        else:
            x_ary = data_array[:, 0]
            y_ary = data_array[:, 1]

        sx = self.index_mapper.map_screen(x_ary)
        sy = self.value_mapper.map_screen(y_ary)
        if self.orientation == "h":
            # Index runs along screen x.
            return transpose(array((sx, sy)))
        else:
            # Index runs along screen y, so swap the columns.
            return transpose(array((sy, sx)))

    def map_data(self, screen_pt, all_values=True):
        """ Maps a screen space point into the "index" space of the plot.

        Overrides the BaseXYPlot implementation and returns an array holding
        the (index, value) pair for the point (``all_values`` is not used
        here).
        """
        x, y = screen_pt
        if self.orientation == 'v':
                x, y = y, x
        return array(
            (self.index_mapper.map_data(x), self.value_mapper.map_data(y)))

    def map_index(self, screen_pt, threshold=0.0, outside_returns_none=True, \
                  index_only = False):
        """ Maps a screen space point to an index into the plot's index array(s).

        Overrides the BaseXYPlot implementation.
        """
        if index_only and self.index.sort_order != "none":
            data_pt = self.map_data(screen_pt)[0]
            # The rest of this was copied out of BaseXYPlot.
            # We can't just use BaseXYPlot.map_index because
            # it expects map_data to return a value, not a pair.
            if ((data_pt < self.index_mapper.range.low) or \
                (data_pt > self.index_mapper.range.high)) and outside_returns_none:
                return None
            index_data = self.index.get_data()
            value_data = self.value.get_data()

            if len(value_data) == 0 or len(index_data) == 0:
                return None

            try:
                ndx = reverse_map_1d(index_data, data_pt, self.index.sort_order)
            # NOTE(review): "except X, e" is Python-2-only syntax; "except
            # IndexError:" would work on both (e is unused).
            except IndexError, e:
                # if reverse_map raises this exception, it means that data_pt is
                # outside the range of values in index_data.
                if outside_returns_none:
                    return None
                else:
                    # Clamp to the nearest end of the index data.
                    if data_pt < index_data[0]:
                        return 0
                    else:
                        return len(index_data) - 1

            if threshold == 0.0:
                # Don't do any threshold testing
                return ndx

            x = index_data[ndx]
            y = value_data[ndx]
            if isnan(x) or isnan(y):
                return None
            sx, sy = self.map_screen([x, y])
            # NOTE(review): threshold == 0.0 already returned above, and this
            # compares a signed x-only distance (no abs()), so any point to the
            # right of screen_pt passes -- confirm this is intended.
            if ((threshold == 0.0) or (screen_pt[0] - sx) < threshold):
                return ndx
            else:
                return None
        else:
            # NOTE(review): the body of this branch is not present in this
            # part of the file.
class TableFilterEditor(HasTraits):
    """ An editor that manages table filters.
    """

    #-------------------------------------------------------------------------
    #  Trait definitions:
    #-------------------------------------------------------------------------

    # TableEditor this editor is associated with
    editor = Instance(TableEditor)

    # The list of filters
    filters = List(TableFilter)

    # The list of available templates from which filters can be created
    templates = Property(List(TableFilter), depends_on='filters')

    # The currently selected filter template
    selected_template = Instance(TableFilter)

    # The currently selected filter
    selected_filter = Instance(TableFilter, allow_none=True)

    # The view to use for the current filter
    selected_filter_view = Property(depends_on='selected_filter')

    # Buttons for add/removing filters
    add_button = Button('New')
    remove_button = Button('Delete')

    # The default view for this editor
    view = View(
        Group(
            Group(
                Group(
                    Item('add_button', enabled_when='selected_template'),
                    Item('remove_button',
                         enabled_when='len(templates) > 1 and '
                                      'selected_filter is not None'),
                    orientation='horizontal',
                    show_labels=False),
                Label('Base filter for new filters:'),
                Item('selected_template',
                     editor=EnumEditor(name='templates')),
                Item('selected_filter', style='custom',
                     editor=EnumEditor(name='filters', mode='list')),
                show_labels=False),
            Item('selected_filter', width=0.75, style='custom',
                 editor=InstanceEditor(view_name='selected_filter_view')),
            id='TableFilterEditorSplit',
            show_labels=False,
            layout='split',
            orientation='horizontal'),
        id='traitsui.qt4.table_editor.TableFilterEditor',
        buttons=['OK', 'Cancel'],
        kind='livemodal',
        resizable=True,
        width=800,
        height=400,
        title='Customize filters')

    #-------------------------------------------------------------------------
    #  Private methods:
    #-------------------------------------------------------------------------

    #-- Trait Property getter/setters ----------------------------------------

    @cached_property
    def _get_selected_filter_view(self):
        """ Return the edit view of the selected filter, or None if none.

        The item behind row 0 of the editor's model (mapped through
        ``mapToSource``) is passed to ``edit_view`` as a sample object; None
        is passed when that index is not valid (presumably an empty table).
        """
        view = None
        if self.selected_filter:
            model = self.editor.model
            index = model.mapToSource(model.index(0, 0))
            if index.isValid():
                obj = self.editor.items()[index.row()]
            else:
                obj = None
            view = self.selected_filter.edit_view(obj)
        return view

    @cached_property
    def _get_templates(self):
        """ Return the factory's template filters followed by our filters. """
        templates = [f for f in self.editor.factory.filters if f.template]
        templates.extend(self.filters)
        return templates

    #-- Trait Change Handlers ------------------------------------------------

    def _editor_changed(self):
        """ Copy the factory's non-template filters and pick a template.

        ``selected_template`` is set to the first template, or None when
        there are no templates at all (previously an IndexError).
        """
        self.filters = [f.clone_traits() for f in self.editor.factory.filters
                        if not f.template]
        templates = self.templates
        self.selected_template = templates[0] if templates else None

    def _add_button_fired(self):
        """ Create a new filter based on the selected template and select it.
        """
        new_filter = self.selected_template.clone_traits()
        new_filter.template = False
        new_filter.name = new_filter._name = 'New filter'
        self.filters.append(new_filter)
        self.selected_filter = new_filter

    def _remove_button_fired(self):
        """ Delete the currently selected filter.

        If the filter being deleted is also the selected template, a
        different template is selected first.  After deletion the filter
        that moved into the deleted slot is selected, or None if the last
        filter was removed.
        """
        doomed = self.selected_filter
        if self.selected_template == doomed:
            # templates[0] is the doomed filter itself when the factory has
            # no template filters and the doomed filter is the first user
            # filter; pick the first template that survives the deletion.
            self.selected_template = next(
                (t for t in self.templates if t is not doomed), None)
        index = self.filters.index(doomed)
        del self.filters[index]
        if index < len(self.filters):
            self.selected_filter = self.filters[index]
        else:
            self.selected_filter = None

    @on_trait_change('selected_filter:name')
    def _update_filter_list(self):
        """ A hack to make the EnumEditor watching the list of filters
            refresh their text when the name of the selected filter changes.
        """
        filters = self.filters
        self.filters = []
        self.filters = filters
class InterpolationEditor(GraphEditor):
    """ Graph editor that plots unknown analyses against reference analyses.

    For each isotope plot it draws the current unknown values, the reference
    values with either an interpolation ('preceding', 'bracketing
    interpolate', 'bracketing average') or a graph fit, and the values
    predicted for the unknowns from that fit.  Analyses can be split into
    bins, each with its own references and regression bounds.
    """

    tool_klass = InterpolationFitSelector
    references = List
    auto_find = Bool(True)
    show_current = Bool(True)

    default_reference_analysis_type = 'air'
    sorted_analyses = Property(depends_on='analyses[]')
    sorted_references = Property(depends_on='references[]')
    binned_analyses = List
    bounds = List
    calculate_reference_age = Bool(False)

    # Divisor applied to timestamp offsets before plotting; the x axis is
    # titled 'Time (hrs)', so timestamps are presumably in seconds.
    _normalization_factor = 3600.

    def _get_min_max(self):
        """ Return (earliest, latest) timestamp over references and analyses.

        Assumes both sorted lists are non-empty (IndexError otherwise).
        """
        mi = min(self.sorted_references[0].timestamp,
                 self.sorted_analyses[0].timestamp)
        ma = max(self.sorted_references[-1].timestamp,
                 self.sorted_analyses[-1].timestamp)
        return mi, ma

    def bin_analyses(self):
        """ Group the analyses into bins and assign references to each bin.

        With more than one group, ``bounds`` receives the bin boundaries
        relative to the (padded) earliest timestamp and ``binned_analyses``
        one BinGroup per group.  The graph is rebuilt afterwards.
        """
        # Note: the module-level bin_analyses() helper, not this method.
        groups = list(bin_analyses(self.analyses))
        self.binned_analyses = []
        self.bounds = []
        n = len(groups)
        if n > 1:
            mi, ma = self._get_min_max()
            # Pad by one unit so the extreme points fall strictly inside.
            mi -= 1
            ma += 1
            bounds = get_bounds(groups)
            # Lists, not map()/filter(): on Python 3 those are lazy, so a
            # List trait would reject them and the filter lambda would
            # late-bind low/high after the loop had moved on.
            self.bounds = [b - mi for b in bounds]
            gs = []
            low = None
            norm = self._normalization_factor
            for i, gi in enumerate(groups):
                if low is None:
                    low = mi
                try:
                    high = bounds[i]
                except IndexError:
                    # The last bin extends to the padded maximum.
                    high = ma

                refs = [r for r in self.sorted_references
                        if low < r.timestamp < high]
                gs.append(
                    BinGroup(unknowns=gi,
                             references=refs,
                             bounds=((low - mi) / norm,
                                     (high - mi) / norm,
                                     i == 0,
                                     i == n - 1)))
                low = high

            self.binned_analyses = gs
        self.rebuild_graph()

    def rebuild_graph(self):
        """ Rebuild the graph and draw a divider at every bin boundary. """
        super(InterpolationEditor, self).rebuild_graph()
        if self.bounds:
            for bi in self.bounds:
                self.add_group_divider(bi)

    def add_group_divider(self, cen):
        """ Draw a vertical rule at offset ``cen`` (scaled to hours). """
        self.graph.add_vertical_rule(cen / self._normalization_factor,
                                     line_width=1.5,
                                     color='lightblue',
                                     line_style='solid')
        self.graph.redraw()

    def find_references(self):
        self._find_references()

    @on_trait_change('references[]')
    def _update_references(self):
        self._update_references_hook()
        self.rebuild_graph()

    def _update_references_hook(self):
        pass

    def _get_start_end(self, rxs, uxs):
        """ Return (start, end) spanning both timestamp lists.

        An empty list contributes +Inf to the minimum and -Inf to the
        maximum, so it does not affect the result.
        """
        mrxs = min(rxs) if rxs else Inf
        muxs = min(uxs) if uxs else Inf

        marxs = max(rxs) if rxs else -Inf
        mauxs = max(uxs) if uxs else -Inf

        start = min(mrxs, muxs)
        end = max(marxs, mauxs)
        return start, end

    def set_auto_find(self, f):
        self.auto_find = f

    def _update_analyses_hook(self):
        self.debug('update analyses hook auto_find={}'.format(self.auto_find))
        if self.auto_find:
            self._find_references()

    def set_references(self, refs, is_append=False, **kw):
        """ Build analyses from ``refs`` and set (or append to) references. """
        ans = self.processor.make_analyses(
            refs,
            # calculate_age=self.calculate_reference_age,
            # unpack=self.unpack_peaktime,
            **kw)

        if is_append:
            pans = self.references
            pans.extend(ans)
            ans = pans

        self.references = ans

    def _find_references(self):
        refs = self.processor.find_references()
        self.references = refs

    def set_interpolated_values(self, iso, reg, ans):
        """ Predict values/errors for ``ans`` (default: sorted analyses).

        Returns ``(p_uys, p_ues)`` from ``reg.predict`` and
        ``reg.predict_error`` evaluated at each analysis' normalized time.
        """
        mi, ma = self._get_min_max()

        if ans is None:
            ans = self.sorted_analyses

        xs = [(ai.timestamp - mi) / self._normalization_factor for ai in ans]

        p_uys = reg.predict(xs)
        p_ues = reg.predict_error(xs)
        self._set_interpolated_values(iso, ans, p_uys, p_ues)
        return p_uys, p_ues

    def _set_interpolated_values(self, *args, **kw):
        pass

    def _get_current_values(self, *args, **kw):
        pass

    def _get_reference_values(self, *args, **kw):
        pass

    def _get_isotope(self, ui, k, kind=None):
        """ Return ``(value, error)`` of isotope ``k`` of ``ui``, or (0, 0).

        When ``kind`` is given, the attribute of that name on the isotope is
        used instead of the isotope itself.
        """
        if k in ui.isotopes:
            v = ui.isotopes[k]
            if kind is not None:
                v = getattr(v, kind)
            v = v.value, v.error
        else:
            v = 0, 0
        return v

    def _rebuild_graph(self):
        """ Build one plot per graph fit, binned or not, and set x limits. """
        graph = self.graph

        # NOTE(review): these x values follow self.analyses/self.references
        # order, while _build_non_binned pairs them with the *sorted* lists;
        # confirm both orders agree.
        uxs = [ui.timestamp for ui in self.analyses]
        rxs = [ui.timestamp for ui in self.references]

        start, end = self._get_start_end(rxs, uxs)
        c_uxs = self.normalize(uxs, start)
        r_xs = self.normalize(rxs, start)

        # Naming: c_... current value, r_... reference value,
        #         p_... predicted value
        gen = self._graph_generator()
        for i, fit in enumerate(gen):
            iso = fit.name
            fit = fit.fit_tuple()
            if self.binned_analyses:
                self._build_binned(i, iso, fit, start)
            else:
                self._build_non_binned(i, iso, fit, c_uxs, r_xs)

            if i == 0:
                self._add_legend()

        m = abs(end - start) / self._normalization_factor
        graph.set_x_limits(0, m, pad='0.1')
        graph.refresh()

    def _build_binned(self, i, iso, fit, start):
        """ Add plot ``i`` for ``iso`` with one series set per bin. """
        graph = self.graph
        p = graph.new_plot(ytitle=iso,
                           xtitle='Time (hrs)',
                           padding=[80, 10, 5, 30])
        p.y_axis.title_spacing = 60
        p.value_range.tight_bounds = False

        for j, gi in enumerate(self.binned_analyses):
            bis = gi.unknowns
            refs = gi.references
            c_xs = self.normalize([bi.timestamp for bi in bis], start)
            rxs = [bi.timestamp for bi in refs]
            r_xs = self.normalize(rxs, start)
            # A list, not map(): asarray(map(...)) is a 0-d object array on
            # Python 3.
            dx = asarray([convert_timestamp(x) for x in rxs])

            c_ys, c_es = None, None
            if self.show_current:
                c_ys, c_es = self._get_current_values(iso, bis)

            r_ys, r_es = None, None
            if refs:
                r_ys, r_es = self._get_reference_values(iso, refs)

            current_args = bis, c_xs, c_ys, c_es
            ref_args = refs, r_xs, r_ys, r_es, dx

            self._build_plot(i, iso, fit, current_args, ref_args,
                             series_id=j,
                             regression_bounds=gi.bounds)

    def _build_non_binned(self, i, iso, fit, c_xs, r_xs):
        """ Add plot ``i`` for ``iso`` using all analyses and references. """
        c_ys, c_es = None, None
        if self.analyses and self.show_current:
            c_ys, c_es = self._get_current_values(iso)

        r_ys, r_es, dx = None, None, None
        if self.references:
            r_ys, r_es = self._get_reference_values(iso)
            dx = asarray([convert_timestamp(ui.timestamp)
                          for ui in self.references])

        current_args = self.sorted_analyses, c_xs, c_ys, c_es
        ref_args = self.sorted_references, r_xs, r_ys, r_es, dx

        graph = self.graph
        p = graph.new_plot(
            ytitle=iso,
            xtitle='Time (hrs)',
            # padding=[80, 10, 5, 30]
            padding=[80, 80, 80, 80])

        p.y_axis.title_spacing = 60
        p.value_range.tight_bounds = False

        self._build_plot(i, iso, fit, current_args, ref_args)

    def _plot_unknowns_current(self, ans, c_es, c_xs, c_ys, i):
        """ Plot the unknowns' current values (black squares) on plot ``i``. """
        graph = self.graph
        if c_es and c_ys:
            # plot unknowns
            s, _p = graph.new_series(c_xs, c_ys,
                                     yerror=c_es,
                                     fit=False,
                                     type='scatter',
                                     plotid=i,
                                     marker='square',
                                     marker_size=3,
                                     bind_id=-1,
                                     color='black',
                                     add_inspector=False)
            self._add_inspector(s, ans)
            self._add_error_bars(s, c_es)

            graph.set_series_label('Unknowns-Current', plotid=i)

    def _plot_interpolated(self, ans, c_xs, i, iso, reg, series_id):
        """ Plot values predicted by ``reg`` for ``ans`` (blue) on plot ``i``. """
        p_uys, p_ues = self.set_interpolated_values(iso, reg, ans)
        if len(p_uys):
            graph = self.graph
            # display the predicted values
            s, p = graph.new_series(c_xs,
                                    p_uys,
                                    isotope=iso,
                                    yerror=ArrayDataSource(p_ues),
                                    fit=False,
                                    add_tools=False,
                                    add_inspector=False,
                                    type='scatter',
                                    marker_size=3,
                                    color='blue',
                                    plotid=i,
                                    bind_id=-1)
            series = len(p.plots) - 1
            graph.set_series_label('Unknowns-predicted{}'.format(series_id),
                                   plotid=i, series=series)

            self._add_error_bars(s, p_ues)

    def _plot_references(self, ref_args, fit, i, regression_bounds, series_id):
        """ Plot the references (red) on plot ``i`` and return their regressor.

        Returns an InterpolationRegressor for the interpolation fit kinds,
        the fitted line's ``regressor`` otherwise, or None when there are no
        reference values or the fitted line exposes no regressor.
        """
        refs, r_xs, r_ys, r_es, display_xs = ref_args

        graph = self.graph
        if not r_ys:
            return

        # Initialised up front: previously, a fitted line without a
        # 'regressor' attribute left reg unbound -> UnboundLocalError.
        reg = None

        # plot references
        efit = fit[0]
        if efit in ['preceding', 'bracketing interpolate',
                    'bracketing average']:
            reg = InterpolationRegressor(xs=r_xs,
                                         ys=r_ys,
                                         yserr=r_es,
                                         kind=efit)
            s, _p = graph.new_series(r_xs, r_ys,
                                     yerror=r_es,
                                     type='scatter',
                                     plotid=i,
                                     fit=False,
                                     marker_size=3,
                                     color='red',
                                     add_inspector=False)
            self._add_inspector(s, refs)
            self._add_error_bars(s, r_es)
        else:
            _p, s, l = graph.new_series(
                r_xs, r_ys,
                display_index=ArrayDataSource(data=display_xs),
                yerror=ArrayDataSource(data=r_es),
                fit=fit,
                color='red',
                plotid=i,
                marker_size=3,
                add_inspector=False)
            if hasattr(l, 'regressor'):
                reg = l.regressor
                l.regression_bounds = regression_bounds

            self._add_inspector(s, refs)
            self._add_error_bars(s, array(r_es))

        return reg

    def _build_plot(self, i, iso, fit, current_args, ref_args,
                    series_id=0, regression_bounds=None):
        """ Plot unknowns, references and (if any) predicted values. """
        ans, c_xs, c_ys, c_es = current_args

        self._plot_unknowns_current(ans, c_es, c_xs, c_ys, i)

        reg = self._plot_references(ref_args, fit, i, regression_bounds,
                                    series_id)

        if reg:
            self._plot_interpolated(ans, c_xs, i, iso, reg, series_id)

    def _add_legend(self):
        pass

    def _add_error_bars(self, scatter, errors, orientation='y', visible=True,
                        nsigma=1, line_width=1):
        """ Attach an ErrorBarOverlay to ``scatter`` and store ``errors``.

        ``errors`` is set as the scatter's ``<orientation>error`` data
        source; the overlay is returned.
        """
        from pychron.graph.error_bar_overlay import ErrorBarOverlay

        ebo = ErrorBarOverlay(component=scatter,
                              orientation=orientation,
                              nsigma=nsigma,
                              line_width=line_width,
                              visible=visible)

        scatter.underlays.append(ebo)
        setattr(scatter, '{}error'.format(orientation),
                ArrayDataSource(errors))
        return ebo

    def _add_inspector(self, scatter, ans):
        """ Add rectangle-selection and point-inspector tools to ``scatter``.

        Selection changes on the scatter's index are propagated to the
        ``temp_status`` of ``ans`` (see _update_metadata).
        """
        broadcaster = BroadcasterTool()
        scatter.tools.append(broadcaster)

        rect_tool = RectSelectionTool(scatter)
        rect_overlay = RectSelectionOverlay(component=scatter,
                                            tool=rect_tool)

        scatter.overlays.append(rect_overlay)
        broadcaster.tools.append(rect_tool)

        point_inspector = AnalysisPointInspector(
            scatter,
            value_format=floatfmt,
            analyses=ans,
            convert_index=lambda x: '{:0.3f}'.format(x))

        pinspector_overlay = PointInspectorOverlay(component=scatter,
                                                   tool=point_inspector)

        scatter.overlays.append(pinspector_overlay)
        scatter.tools.append(point_inspector)
        broadcaster.tools.append(point_inspector)

        scatter.index.on_trait_change(self._update_metadata(ans),
                                      'metadata_changed')

    def _update_metadata(self, ans):
        """ Return a handler that mirrors index selections onto ``ans``.

        The handler sets ``ai.temp_status`` to whether the i-th analysis'
        index is in the data source's 'selections' metadata.
        """
        def _handler(obj, name, old, new):
            meta = obj.metadata
            # A set makes each membership test O(1); a missing key means
            # nothing is selected rather than a KeyError.
            selections = set(meta.get('selections', ()))
            for i, ai in enumerate(ans):
                ai.temp_status = i in selections

        return _handler

    @cached_property
    def _get_sorted_analyses(self):
        return sorted(self.analyses, key=lambda x: x.timestamp)

    @cached_property
    def _get_sorted_references(self):
        return sorted(self.references, key=lambda x: x.timestamp)

    @on_trait_change('graph:regression_results')
    def _update_regression(self, new):
        """ Recompute predicted series after the graph's regressions change.

        ``new`` is a sequence of ``(plotobj, reg)`` pairs.  In binned mode the
        pairs are consumed in (fit, bin) order; otherwise they are zipped with
        the graph fits.  Only BaseRegressor results update the
        'Unknowns-predicted<n>' series.
        """
        key = 'Unknowns-predicted{}'
        # necessary to handle user excluding points
        if self.binned_analyses:
            gen = self._graph_generator()
            c = 0
            for j, fit in enumerate(gen):
                for i, g in enumerate(self.binned_analyses):
                    try:
                        plotobj, reg = new[c]
                    except IndexError:
                        break

                    if issubclass(type(reg), BaseRegressor):
                        k = key.format(i)
                        self._set_values(fit, plotobj, reg, k, g.unknowns)
                    c += 1
        else:
            key = key.format(0)
            gen = self._graph_generator()
            for fit, (plotobj, reg) in zip(gen, new):
                if issubclass(type(reg), BaseRegressor):
                    self._set_values(fit, plotobj, reg, key)

    def _set_values(self, fit, plotobj, reg, key, ans=None):
        """ Refresh the predicted series ``key`` on ``plotobj`` from ``reg``. """
        iso = fit.name
        if key in plotobj.plots:
            scatter = plotobj.plots[key][0]
            p_uys, p_ues = self.set_interpolated_values(iso, reg, ans)
            scatter.value.set_data(p_uys)
            scatter.yerror.set_data(p_ues)

    def _clean_references(self):
        """ Return the references whose ``temp_status`` equals 0. """
        return [ri for ri in self.references if ri.temp_status == 0]
class Volume(Module):
    """The Volume module visualizes scalar fields using volumetric
    visualization techniques.  This supports ImageData and
    UnstructuredGrid data.  It also supports the FixedPointRenderer
    for ImageData.  However, the performance is slow so your best bet
    is probably with the ImageData based renderers.
    """

    # The version of this class.  Used for persistence.
    __version__ = 0

    volume_mapper_type = DEnum(values_name='_mapper_types',
                               desc='volume mapper to use')

    ray_cast_function_type = DEnum(values_name='_ray_cast_functions',
                                   desc='Ray cast function to use')

    volume = ReadOnly

    volume_mapper = Property(record=True)

    volume_property = Property(record=True)

    ray_cast_function = Property(record=True)

    lut_manager = Instance(VolumeLUTManager, args=(), allow_none=False,
                           record=True)

    input_info = PipelineInfo(datasets=['image_data', 'unstructured_grid'],
                              attribute_types=['any'],
                              attributes=['scalars'])

    ########################################
    # View related code.

    update_ctf = Button('Update CTF')

    view = View(Group(Item(name='_volume_property', style='custom',
                           editor=VolumePropertyEditor,
                           resizable=True),
                      Item(name='update_ctf'),
                      label='CTF',
                      show_labels=False),
                Group(Item(name='volume_mapper_type'),
                      Group(Item(name='_volume_mapper',
                                 style='custom',
                                 resizable=True),
                            show_labels=False),
                      Item(name='ray_cast_function_type'),
                      Group(Item(name='_ray_cast_function',
                                 enabled_when='len(_ray_cast_functions) > 0',
                                 style='custom',
                                 resizable=True),
                            show_labels=False),
                      label='Mapper',
                      ),
                Group(Item(name='_volume_property', style='custom',
                           resizable=True),
                      label='Property',
                      show_labels=False),
                Group(Item(name='volume', style='custom',
                           editor=InstanceEditor(),
                           resizable=True),
                      label='Volume',
                      show_labels=False),
                Group(Item(name='lut_manager', style='custom',
                           resizable=True),
                      label='Legend',
                      show_labels=False),
                resizable=True)

    ########################################
    # Private traits
    _volume_mapper = Instance(tvtk.AbstractVolumeMapper)
    _volume_property = Instance(tvtk.VolumeProperty)
    _ray_cast_function = Instance(tvtk.Object)

    _mapper_types = List(Str, ['TextureMapper2D', 'RayCastMapper', ])

    _available_mapper_types = List(Str)

    _ray_cast_functions = List(Str)

    current_range = Tuple

    # The color transfer function.
    _ctf = Instance(ColorTransferFunction)
    # The opacity values.
    _otf = Instance(PiecewiseFunction)

    # A cache for the mappers, a dict keyed by class.
    _mapper_cache = Dict

    ######################################################################
    # `object` interface
    ######################################################################
    def __get_pure_state__(self):
        """ Pickle state: saved CTFs replace the range and transfer traits. """
        d = super(Volume, self).__get_pure_state__()
        d['ctf_state'] = save_ctfs(self._volume_property)
        for name in ('current_range', '_ctf', '_otf'):
            d.pop(name, None)
        return d

    def __set_pure_state__(self, state):
        """ Restore the mapper type, traits and transfer functions. """
        self.volume_mapper_type = state['_volume_mapper_type']
        state_pickler.set_state(self, state, ignore=['ctf_state'])
        ctf_state = state['ctf_state']
        ctf, otf = load_ctfs(ctf_state, self._volume_property)
        self._ctf = ctf
        self._otf = otf
        self._update_ctf_fired()

    ######################################################################
    # `Module` interface
    ######################################################################
    def start(self):
        super(Volume, self).start()
        self.lut_manager.start()

    def stop(self):
        super(Volume, self).stop()
        self.lut_manager.stop()

    def setup_pipeline(self):
        """Override this method so that it *creates* the tvtk
        pipeline.

        Creates the volume and its property with default color/opacity
        transfer functions over 0-255, and records which mappers exist.
        """
        v = self.volume = tvtk.Volume()
        vp = self._volume_property = tvtk.VolumeProperty()

        self._ctf = ctf = default_CTF(0, 255)
        self._otf = otf = default_OTF(0, 255)
        vp.set_color(ctf)
        vp.set_scalar_opacity(otf)
        vp.shade = True
        vp.interpolation_type = 'linear'
        v.property = vp

        v.on_trait_change(self.render)
        vp.on_trait_change(self.render)

        available_mappers = find_volume_mappers()

        if is_volume_pro_available():
            self._mapper_types.append('VolumeProMapper')
            available_mappers.append('VolumeProMapper')

        self._available_mapper_types = available_mappers
        if 'FixedPointVolumeRayCastMapper' in available_mappers:
            self._mapper_types.append('FixedPointVolumeRayCastMapper')

        self.actors.append(v)

    def update_pipeline(self):
        """Override this method so that it *updates* the tvtk pipeline
        when data upstream is known to have changed.

        This method is invoked (automatically) when any of the inputs
        sends a `pipeline_changed` event.
        """
        mm = self.module_manager
        if mm is None:
            return

        dataset = mm.source.get_output_dataset()

        ug = hasattr(tvtk, 'UnstructuredGridVolumeMapper')
        if dataset.is_a('vtkMultiBlockDataSet'):
            if 'MultiBlockVolumeMapper' not in self._available_mapper_types:
                error('Your version of VTK does not support '
                      'Multiblock volume rendering')
                return
        elif dataset.is_a('vtkUniformGridAMR'):
            if 'AMRVolumeMapper' not in self._available_mapper_types:
                error('Your version of VTK does not support '
                      'AMR volume rendering')
                return
        elif ug:
            if not dataset.is_a('vtkImageData') \
                    and not dataset.is_a('vtkUnstructuredGrid'):
                error('Volume rendering only works with '
                      'StructuredPoints/ImageData/UnstructuredGrid datasets')
                return
        elif not dataset.is_a('vtkImageData'):
            error('Volume rendering only works with '
                  'StructuredPoints/ImageData datasets')
            return

        self._setup_mapper_types()
        self._setup_current_range()
        self._volume_mapper_type_changed(self.volume_mapper_type)
        self._update_ctf_fired()
        self.pipeline_changed = True

    def update_data(self):
        """Override this method so that it flushes the vtk pipeline if
        that is necessary.

        This method is invoked (automatically) when any of the inputs
        sends a `data_changed` event.
        """
        self._setup_mapper_types()
        self._setup_current_range()
        self._update_ctf_fired()
        self.data_changed = True

    ######################################################################
    # Non-public methods.
    ######################################################################
    def _get_image_data_volume_mappers(self):
        """ Return the available GPU/smart mappers usable with ImageData. """
        check = ('SmartVolumeMapper', 'GPUVolumeRayCastMapper',
                 'OpenGLGPUVolumeRayCastMapper')
        return [x for x in check if x in self._available_mapper_types]

    def _setup_mapper_types(self):
        """Sets up the mapper based on input data types.
        """
        dataset = self.module_manager.source.get_output_dataset()
        if dataset.is_a('vtkMultiBlockDataSet'):
            if 'MultiBlockVolumeMapper' in self._available_mapper_types:
                self._mapper_types = ['MultiBlockVolumeMapper']
        elif dataset.is_a('vtkUniformGridAMR'):
            # Offer the AMR mapper only when this VTK build provides it, as
            # the MultiBlock branch does.  The test used to be inverted
            # ("not in"), so AMR data never got the AMR mapper.
            if 'AMRVolumeMapper' in self._available_mapper_types:
                self._mapper_types = ['AMRVolumeMapper']
        elif dataset.is_a('vtkUnstructuredGrid'):
            if hasattr(tvtk, 'UnstructuredGridVolumeMapper'):
                check = ['UnstructuredGridVolumeZSweepMapper',
                         'UnstructuredGridVolumeRayCastMapper',
                         ]
                mapper_types = []
                for mapper in check:
                    if mapper in self._available_mapper_types:
                        mapper_types.append(mapper)
                if len(mapper_types) == 0:
                    # Placeholder entry when no unstructured-grid mapper
                    # exists; _volume_mapper_type_changed ignores it.
                    mapper_types = ['']
                self._mapper_types = mapper_types
                return
        else:
            mapper_types = self._get_image_data_volume_mappers()
            if dataset.point_data.scalars.data_type not in \
                    [vtkConstants.VTK_UNSIGNED_CHAR,
                     vtkConstants.VTK_UNSIGNED_SHORT]:
                if 'FixedPointVolumeRayCastMapper' \
                        in self._available_mapper_types:
                    mapper_types.append('FixedPointVolumeRayCastMapper')
                elif len(mapper_types) == 0:
                    error('Available volume mappers only work with '
                          'unsigned_char or unsigned_short datatypes')
            else:
                check = ['FixedPointVolumeRayCastMapper',
                         'VolumeProMapper',
                         'TextureMapper2D',
                         'RayCastMapper',
                         'TextureMapper3D']
                for mapper in check:
                    if mapper in self._available_mapper_types:
                        mapper_types.append(mapper)
            self._mapper_types = mapper_types

    def _setup_current_range(self):
        """ Sync LUT defaults with the scalar LUT and update current_range. """
        mm = self.module_manager
        # Set the default name and range for our lut.
        lm = self.lut_manager
        slm = mm.scalar_lut_manager
        lm.trait_set(default_data_name=slm.default_data_name,
                     default_data_range=slm.default_data_range)

        # Set the current range.
        dataset = mm.source.get_output_dataset()
        dsh = DataSetHelper(dataset)
        name, rng = dsh.get_range('scalars', 'point')
        if name is None:
            error('No scalars in input data!')
            rng = (0, 255)

        if self.current_range != rng:
            self.current_range = rng

    def _get_volume_mapper(self):
        return self._volume_mapper

    def _get_volume_property(self):
        return self._volume_property

    def _get_ray_cast_function(self):
        return self._ray_cast_function

    def _get_mapper(self, klass):
        """ Return a mapper of the given class.  Either from the cache or by
        making a new one."""
        result = self._mapper_cache.get(klass)
        if result is None:
            result = klass()
            self._mapper_cache[klass] = result
        return result

    def _volume_mapper_type_changed(self, value):
        """ Swap in the mapper named by ``value`` and connect it.

        The previous mapper is detached first.  Unknown names, such as the
        '' placeholder used when no unstructured-grid mapper exists, leave
        no mapper connected instead of raising NameError.
        """
        mm = self.module_manager
        if mm is None:
            return

        old_vm = self._volume_mapper
        if old_vm is not None:
            old_vm.on_trait_change(self.render, remove=True)
            try:
                old_vm.remove_all_input_connections(0)
            except AttributeError:
                pass

        if value == 'RayCastMapper':
            new_vm = self._get_mapper(tvtk.VolumeRayCastMapper)
            # _volume_mapper must be set before _ray_cast_functions: changing
            # the latter can fire _ray_cast_function_type_changed, which
            # writes to self._volume_mapper.
            self._volume_mapper = new_vm
            self._ray_cast_functions = ['RayCastCompositeFunction',
                                        'RayCastMIPFunction',
                                        'RayCastIsosurfaceFunction']
            new_vm.volume_ray_cast_function = self._get_mapper(
                tvtk.VolumeRayCastCompositeFunction)
        else:
            # Mapper type -> tvtk class name, for mappers without a
            # selectable ray cast function.  The class is resolved lazily so
            # that only the chosen class must exist in this VTK build.
            tvtk_class_names = {
                'MultiBlockVolumeMapper': 'MultiBlockVolumeMapper',
                'AMRVolumeMapper': 'AMRVolumeMapper',
                'SmartVolumeMapper': 'SmartVolumeMapper',
                'GPUVolumeRayCastMapper': 'GPUVolumeRayCastMapper',
                'OpenGLGPUVolumeRayCastMapper':
                    'OpenGLGPUVolumeRayCastMapper',
                'TextureMapper2D': 'VolumeTextureMapper2D',
                'TextureMapper3D': 'VolumeTextureMapper3D',
                'VolumeProMapper': 'VolumeProMapper',
                'FixedPointVolumeRayCastMapper':
                    'FixedPointVolumeRayCastMapper',
                'UnstructuredGridVolumeRayCastMapper':
                    'UnstructuredGridVolumeRayCastMapper',
                'UnstructuredGridVolumeZSweepMapper':
                    'UnstructuredGridVolumeZSweepMapper',
            }
            class_name = tvtk_class_names.get(value)
            if class_name is None:
                return
            new_vm = self._get_mapper(getattr(tvtk, class_name))
            self._volume_mapper = new_vm
            self._ray_cast_functions = ['']

        src = mm.source
        self.configure_input(new_vm, src.outputs[0])
        self.volume.mapper = new_vm
        new_vm.on_trait_change(self.render)

    def _update_ctf_fired(self):
        set_lut(self.lut_manager.lut, self._volume_property)
        self.render()

    def _current_range_changed(self, old, new):
        rescale_ctfs(self._volume_property, new)
        self.render()

    def _ray_cast_function_type_changed(self, old, new):
        """ Replace the ray cast function; '' means none. """
        rcf = self.ray_cast_function
        if len(old) > 0:
            rcf.on_trait_change(self.render, remove=True)

        if len(new) > 0:
            new_rcf = getattr(tvtk, 'Volume%s' % new)()
            new_rcf.on_trait_change(self.render)
            self._volume_mapper.volume_ray_cast_function = new_rcf
            self._ray_cast_function = new_rcf
        else:
            self._ray_cast_function = None

        self.render()

    def _scene_changed(self, old, new):
        super(Volume, self)._scene_changed(old, new)
        self.lut_manager.scene = new
class NamespaceView(View):
    """ A view that lists the names bound in the Python shell namespace. """

    #### 'IView' interface ####################################################

    # Globally unique identifier of this view.
    id = "enthought.plugins.python_shell.view.namespace_view"

    # Human-readable name of the view.
    name = "Namespace"

    # Placement of the view relative to the item named by the 'relative_to'
    # trait.
    position = "left"

    #### 'NamespaceView' interface ############################################

    # One HasTraits record per namespace binding, each exposing 'name',
    # 'type' and 'module' string attributes. The cached value is recomputed
    # whenever 'namespace' changes.
    bindings = Property(List, observe="namespace")

    # The shell view whose namespace is displayed; assigned in
    # create_control from the IPythonShell service.
    shell_view = Instance(PythonShellView)

    # Delegates to 'shell_view.namespace'.
    namespace = DelegatesTo("shell_view")

    # Default traits UI view: a bordered, label-less group holding the
    # bindings table.
    traits_view = TraitsView(
        VGroup(
            Item(
                "bindings",
                id="table",
                editor=table_editor,
                springy=True,
                resizable=True,
            ),
            show_border=True,
            show_labels=False,
        ),
        resizable=True,
    )

    ###########################################################################
    # 'View' interface.
    ###########################################################################

    def create_control(self, parent):
        """ Build and return the toolkit-specific control for this view.

        'parent' is the toolkit-specific control that will contain it. The
        traits UI is created first; afterwards the IPythonShell service is
        looked up and stored as 'shell_view'.
        """
        self.ui = self.edit_traits(parent=parent, kind="subpanel")

        # The IPythonShell service is expected to be a PythonShellView (from
        # envisage.plugins.python_shell.view.python_shell_view), as required
        # by the 'shell_view' trait.
        application = self.window.application
        self.shell_view = application.get_service(IPythonShell)

        return self.ui.control

    ###########################################################################
    # 'NamespaceView' interface.
    ###########################################################################

    #### Properties ###########################################################

    @cached_property
    def _get_bindings(self):
        """ Property getter: one record per entry of the shell namespace.

        Returns an empty list while no shell view is attached.
        """
        if self.shell_view is None:
            return []

        class _Binding(HasTraits):
            name = Str
            type = Str
            module = Str

        entries = self.shell_view.namespace.items()
        return [
            _Binding(
                name=key, type=type_to_str(obj), module=module_to_str(obj)
            )
            for key, obj in entries
        ]
class MRISubjectSource(HasPrivateTraits):
    """Find subjects in SUBJECTS_DIR and select one.

    Parameters
    ----------
    subjects_dir : directory
        SUBJECTS_DIR.
    subject : str
        Subject, corresponding to a folder in SUBJECTS_DIR.
    """

    # Firing this event invalidates the cached 'subjects' list.
    refresh = Event(desc="Refresh the subject list based on the directory "
                    "structure of subjects_dir.")

    # settings
    subjects_dir = Directory(exists=True)
    subjects = Property(List(Str), depends_on=['subjects_dir', 'refresh'])
    subject = Enum(values='subjects')
    use_high_res_head = Bool(True)

    # info
    can_create_fsaverage = Property(Bool, depends_on=['subjects_dir',
                                                      'subjects'])
    subject_has_bem = Property(Bool, depends_on=['subjects_dir', 'subject'],
                               desc="whether the subject has a file matching "
                               "the bem file name pattern")
    # NOTE(review): this class declares no 'mri_dir' trait (only a
    # _get_mri_dir method) and no _get_bem_pattern getter; confirm whether
    # they are provided elsewhere.
    bem_pattern = Property(depends_on='mri_dir')

    @cached_property
    def _get_can_create_fsaverage(self):
        # Only possible in an existing directory that has no fsaverage yet.
        return (os.path.exists(self.subjects_dir) and
                'fsaverage' not in self.subjects)

    @cached_property
    def _get_mri_dir(self):
        # Both the subject and the subjects directory are needed for a path.
        if self.subject and self.subjects_dir:
            return os.path.join(self.subjects_dir, self.subject)
        return None

    @cached_property
    def _get_subjects(self):
        root = self.subjects_dir
        if not (root and os.path.isdir(root)):
            return ['']
        found = [entry for entry in os.listdir(root)
                 if _is_mri_subject(entry, root)]
        # Always offer at least one (empty) choice for the 'subject' Enum.
        return found or ['']

    @cached_property
    def _get_subject_has_bem(self):
        if self.subject:
            return _mri_subject_has_bem(self.subject, self.subjects_dir)
        return False

    def create_fsaverage(self):
        """Create the fsaverage subject in ``subjects_dir`` and select it.

        The subject is created with ``create_default_subject`` using the MNE
        root and the FreeSurfer home directory.

        Raises
        ------
        RuntimeError
            If ``subjects_dir`` is not set, or if the MNE root or the
            FreeSurfer home directory cannot be located.
        """
        if not self.subjects_dir:
            raise RuntimeError("No subjects directory is selected. Please "
                               "specify subjects_dir first.")

        mne_root = get_mne_root()
        if mne_root is None:
            raise RuntimeError("MNE contains files that are needed for "
                               "copying the fsaverage brain. Please install "
                               "MNE and try again.")

        fs_home = get_fs_home()
        if fs_home is None:
            raise RuntimeError("FreeSurfer contains files that are "
                               "needed for copying the fsaverage "
                               "brain. Please install FreeSurfer and "
                               "try again.")

        create_default_subject(mne_root, fs_home,
                               subjects_dir=self.subjects_dir)
        # Rebuild the subject list so 'fsaverage' becomes a valid choice,
        # then select it.
        self.refresh = True
        self.subject = 'fsaverage'
class Axis(ConfigLoadable):
    '''
    Motion parameters and configuration helpers for a single axis.

    Commands are forwarded to ``parent`` (see ``ask``); ``parent`` is declared
    with ``transient=True``.
    '''
    id = Int
    # 'name' is not declared here, but _get_parameters reads self.name;
    # presumably it comes from ConfigLoadable -- verify.
    position = Float
    negative_limit = Float
    positive_limit = Float
    pdir = Str
    parent = Any(transient=True)
    calculate_parameters = Bool(True)
    drive_ratio = Float(1)

    # Only getters are defined in this class, so these properties are
    # read-only here; the underscore traits hold the actual values.
    velocity = Property(depends_on='_velocity')
    _velocity = Float(enter_set=True, auto_set=False)
    acceleration = Property(depends_on='_acceleration')
    _acceleration = Float(enter_set=True, auto_set=False)
    deceleration = Property(depends_on='_deceleration')
    _deceleration = Float(enter_set=True, auto_set=False)

    # Kept in sync with the underscore traits by update_machine_values.
    machine_velocity = Float
    machine_acceleration = Float
    machine_deceleration = Float

    # sets handled by child class
    nominal_velocity = Float
    nominal_acceleration = Float
    nominal_deceleration = Float
    sign = CInt(1)

    def _get_velocity(self):
        return self._velocity

    def _get_acceleration(self):
        return self._acceleration

    def _get_deceleration(self):
        return self._deceleration

    def upload_parameters_to_device(self):
        '''Placeholder; does nothing in this class.'''

    @on_trait_change('_velocity, _acceleration, _deceleration')
    def update_machine_values(self, obj, name, old, new):
        '''
        Copy a changed underscore trait into its ``machine`` counterpart,
        e.g. ``_velocity`` -> ``machine_velocity``.
        '''
        target = 'machine{}'.format(name)
        setattr(self, target, new)

    def _calibration_changed(self):
        self.parent.update_axes()

    def simple_view(self):
        '''
        Return a View of the motion parameters. Velocity, acceleration and
        deceleration are editable only when calculate_parameters is off.
        '''
        editable = dict(format_str='%0.3f',
                        enabled_when='not calculate_parameters')
        return View(Item('calculate_parameters'),
                    Item('velocity', **editable),
                    Item('acceleration', **editable),
                    Item('deceleration', **editable),
                    Item('drive_ratio'))

    def full_view(self):
        return self.simple_view()

    def dump(self):
        '''
        No-op. An earlier, now-disabled version pickled the axis to
        ``os.path.join(self.pdir, '.' + self.name)``.
        '''

    def save(self):
        pass

    def ask(self, cmd):
        '''Forward ``cmd`` to the parent and return its reply.'''
        return self.parent.ask(cmd)

    def _get_parameters(self, path):
        '''
        Return every item of every section of this axis' configuration.

        If ``path`` is not a file it is treated as a directory containing
        ``'<name>axis.cfg'``. Returns an empty list when no configuration is
        obtained. Items are presumably (option, value) pairs from a
        ConfigParser-like object -- confirm against get_configuration.
        '''
        if not os.path.isfile(path):
            path = os.path.join(path, '{}axis.cfg'.format(self.name))

        config = self.get_configuration(path)
        if not config:
            return []
        return [option
                for section in config.sections()
                for option in config.items(section)]

    def _validate_float(self, v):
        '''Return ``float(v)``, or None if the conversion raises ValueError.'''
        try:
            return float(v)
        except ValueError:
            return None