class AdaptationOffer(HasTraits):
    """ An offer to provide adapters from one protocol to another.

    An adaptation offer consists of a factory that can create adapters, and the
    protocols that define what the adapters adapt from and to.

    The factory and both protocols may be given lazily as dotted-path strings;
    they are imported the first time the corresponding property is read.

    """

    #### 'object' protocol ####################################################

    def __repr__(self):
        """ Return a string representation of the object. """

        template = "<AdaptationOffer: '{from_}' -> '{to}'>"

        from_ = self.from_protocol_name
        to = self.to_protocol_name

        return template.format(from_=from_, to=to)

    #### 'AdaptationOffer' protocol ###########################################

    #: A factory for creating adapters.
    #:
    #: The factory must be a callable that takes exactly one argument which is
    #: the object to be adapted (known as the adaptee), and returns an
    #: adapter from the `from_protocol` to the `to_protocol`.
    #:
    #: The factory can be specified as either a callable, or a string in the
    #: form 'foo.bar.baz' which is turned into an import statement
    #: 'from foo.bar import baz' and imported when the trait is first accessed.
    factory = Property(Any)

    #: Adapters created by the factory adapt *from* this protocol.
    #:
    #: The protocol can be specified as a protocol (class/Interface), or a
    #: string in the form 'foo.bar.baz' which is turned into an import
    #: statement 'from foo.bar import baz' and imported when the trait is
    #: accessed.
    from_protocol = Property(Any)

    #: The full dotted name of `from_protocol`. Reading it does not trigger
    #: the lazy import of the protocol.
    from_protocol_name = Property(Any)

    def _get_from_protocol_name(self):
        """ Trait property getter (reads the shadow trait, no import). """
        return self._get_type_name(self._from_protocol)

    #: Adapters created by the factory adapt *to* this protocol.
    #:
    #: The protocol can be specified as a protocol (class/Interface), or a
    #: string in the form 'foo.bar.baz' which is turned into an import
    #: statement 'from foo.bar import baz' and imported when the trait is
    #: accessed.
    to_protocol = Property(Any)

    #: The full dotted name of `to_protocol`. Reading it does not trigger
    #: the lazy import of the protocol.
    to_protocol_name = Property(Any)

    def _get_to_protocol_name(self):
        """ Trait property getter (reads the shadow trait, no import). """
        return self._get_type_name(self._to_protocol)

    #### Private protocol ######################################################

    #: Shadow trait for the corresponding property.
    _factory = Any
    #: True once `_factory` has been resolved (imported if it was a string).
    _factory_loaded = Bool(False)

    def _get_factory(self):
        """ Trait property getter; imports a dotted-string factory lazily. """

        if not self._factory_loaded:
            if isinstance(self._factory, str):
                self._factory = import_symbol(self._factory)

            self._factory_loaded = True

        return self._factory

    def _set_factory(self, factory):
        """ Trait property setter. """

        self._factory = factory
        # Force the next read to re-resolve the value; otherwise a string
        # assigned after the first access would never be imported.
        self._factory_loaded = False

    #: Shadow trait for the corresponding property.
    _from_protocol = Any
    #: True once `_from_protocol` has been resolved (imported if a string).
    _from_protocol_loaded = Bool(False)

    def _get_from_protocol(self):
        """ Trait property getter; imports a dotted-string protocol lazily. """

        if not self._from_protocol_loaded:
            if isinstance(self._from_protocol, str):
                self._from_protocol = import_symbol(self._from_protocol)

            self._from_protocol_loaded = True

        return self._from_protocol

    def _set_from_protocol(self, from_protocol):
        """ Trait property setter. """

        self._from_protocol = from_protocol
        # Re-resolve on next read (see `_set_factory`).
        self._from_protocol_loaded = False

    #: Shadow trait for the corresponding property.
    _to_protocol = Any
    #: True once `_to_protocol` has been resolved (imported if a string).
    _to_protocol_loaded = Bool(False)

    def _get_to_protocol(self):
        """ Trait property getter; imports a dotted-string protocol lazily. """

        if not self._to_protocol_loaded:
            if isinstance(self._to_protocol, str):
                self._to_protocol = import_symbol(self._to_protocol)

            self._to_protocol_loaded = True

        return self._to_protocol

    def _set_to_protocol(self, to_protocol):
        """ Trait property setter. """

        self._to_protocol = to_protocol
        # Re-resolve on next read (see `_set_factory`).
        self._to_protocol_loaded = False

    def _get_type_name(self, type_or_type_name):
        """ Returns the full dotted path for a type.

        For example:
        from traits.api import HasTraits
        _get_type_name(HasTraits) == 'traits.has_traits.HasTraits'

        If the type is given as a string (e.g., for lazy loading), it is just
        returned. Objects without ``__module__``/``__name__`` (for instance
        ``None`` when the protocol has not been set yet) are rendered with
        ``repr`` so that ``__repr__`` never raises.

        """

        if isinstance(type_or_type_name, str):
            return type_or_type_name

        try:
            module = type_or_type_name.__module__
            name = type_or_type_name.__name__
        except AttributeError:
            return repr(type_or_type_name)

        return "{module}.{name}".format(module=module, name=name)
# Example No. 2 (0)
class TableEditor(Editor, BaseTableEditor):
    """ Editor that presents data in a table. Optionally, tables can have
        a set of filters that reduce the set of data displayed, according to
        their criteria.
    """

    #-------------------------------------------------------------------------
    #  Trait definitions:
    #-------------------------------------------------------------------------

    # The table view control associated with the editor:
    table_view = Any

    def _table_view_default(self):
        return TableView(editor=self)

    # A wrapper around the source model which provides filtering and sorting:
    model = Instance(SortFilterTableModel)

    def _model_default(self):
        return SortFilterTableModel(editor=self)

    # The table model associated with the editor:
    source_model = Instance(TableModel)

    def _source_model_default(self):
        return TableModel(editor=self)

    # The set of columns currently defined on the editor:
    columns = List(TableColumn)

    # The currently selected row(s), column(s), or cell(s).
    selected = Any

    # The current selected row
    selected_row = Property(Any, depends_on='selected')

    # The (row, column) source-model indices of the current selection
    selected_indices = Property(Any, depends_on='selected')

    # Current filter object (should be a TableFilter or callable or None):
    filter = Any

    # The indices of the table items currently passing the table filter:
    filtered_indices = List(Int)

    # Current filter summary message
    filter_summary = Str('All items')

    # Update the filtered contents.
    update_filter = Event()

    # The event fired when a cell is clicked on:
    click = Event

    # The event fired when a cell is double-clicked on:
    dclick = Event

    # The Traits UI associated with the table editor toolbar:
    toolbar_ui = Instance(UI)

    # The context menu associated with empty space in the table
    empty_menu = Instance(QtGui.QMenu)

    # The context menu associated with the vertical header
    header_menu = Instance(QtGui.QMenu)

    # The context menu actions for moving rows up and down
    header_menu_up = Instance(QtGui.QAction)
    header_menu_down = Instance(QtGui.QAction)

    # The index of the row that was last right clicked on its vertical header
    header_row = Int

    # Whether to auto-size the columns or not.
    auto_size = Bool(False)

    # Dictionary mapping image names to QIcons
    images = Any({})

    # Dictionary mapping ImageResource objects to QIcons
    image_resources = Any({})

    # An image being converted:
    image = Image

    #-------------------------------------------------------------------------
    #  Finishes initializing the editor by creating the underlying toolkit
    #  widget:
    #-------------------------------------------------------------------------

    def init(self, parent):
        """Finishes initializing the editor by creating the underlying toolkit
        widget.

        Builds the view/model stack, the context menus, the optional filter
        toolbar and row editor, and wires up all Qt signals and trait
        listeners.
        """

        factory = self.factory
        self.columns = factory.columns[:]
        if factory.table_view_factory is not None:
            self.table_view = factory.table_view_factory(editor=self)
        if factory.source_model_factory is not None:
            self.source_model = factory.source_model_factory(editor=self)
        if factory.model_factory is not None:
            self.model = factory.model_factory(editor=self)

        # Create the table view and model
        self.model.setDynamicSortFilter(True)
        self.model.setSourceModel(self.source_model)
        self.table_view.setModel(self.model)

        # Create the vertical header context menu and connect to its signals
        self.header_menu = QtGui.QMenu(self.table_view)
        signal = QtCore.SIGNAL('triggered()')
        insertable = factory.row_factory is not None and not factory.auto_add
        if factory.editable:
            if insertable:
                action = self.header_menu.addAction('Insert new item')
                QtCore.QObject.connect(action, signal, self._on_context_insert)
            if factory.deletable:
                action = self.header_menu.addAction('Delete item')
                QtCore.QObject.connect(action, signal, self._on_context_remove)
        if factory.reorderable:
            if factory.editable and (insertable or factory.deletable):
                self.header_menu.addSeparator()
            self.header_menu_up = self.header_menu.addAction('Move item up')
            QtCore.QObject.connect(self.header_menu_up, signal,
                                   self._on_context_move_up)
            self.header_menu_down = self.header_menu.addAction(
                'Move item down')
            QtCore.QObject.connect(self.header_menu_down, signal,
                                   self._on_context_move_down)

        # Create the empty space context menu and connect its signals
        self.empty_menu = QtGui.QMenu(self.table_view)
        action = self.empty_menu.addAction('Add new item')
        QtCore.QObject.connect(action, signal, self._on_context_append)

        # When sorting is enabled, the first column is initially displayed with
        # the triangle indicating it is the sort index, even though no sorting
        # has actually been done. Sort here for UI/model consistency.
        if factory.sortable and not factory.reorderable:
            self.model.sort(0, QtCore.Qt.AscendingOrder)

        # Connect to the mode specific selection handler and select the first
        # row/column/cell. Do this before creating the edit_view to make sure
        # that it has a valid item to use when constructing its view.
        smodel = self.table_view.selectionModel()
        signal = QtCore.SIGNAL(
            'selectionChanged(QItemSelection, QItemSelection)')
        mode_slot = getattr(self, '_on_%s_selection' % factory.selection_mode)
        QtCore.QObject.connect(smodel, signal, mode_slot)
        self.table_view.setCurrentIndex(self.model.index(0, 0))

        # Create the toolbar if necessary
        if factory.show_toolbar and len(factory.filters) > 0:
            main_view = QtGui.QWidget()
            layout = QtGui.QVBoxLayout(main_view)
            layout.setContentsMargins(0, 0, 0, 0)
            self.toolbar_ui = self.edit_traits(
                parent=parent,
                kind='subpanel',
                view=View(Group(Item('filter{View}',
                                     editor=factory._filter_editor),
                                Item('filter_summary{Results}',
                                     style='readonly'),
                                spring,
                                orientation='horizontal'),
                          resizable=True))
            self.toolbar_ui.parent = self.ui
            layout.addWidget(self.toolbar_ui.control)
            layout.addWidget(self.table_view)
        else:
            main_view = self.table_view

        # Create auxiliary editor and encompassing splitter if necessary
        mode = factory.selection_mode
        if (factory.edit_view == ' ') or mode not in ('row', 'rows'):
            self.control = main_view
        else:
            if factory.orientation == 'horizontal':
                self.control = QtGui.QSplitter(QtCore.Qt.Horizontal)
            else:
                self.control = QtGui.QSplitter(QtCore.Qt.Vertical)
            self.control.setSizePolicy(QtGui.QSizePolicy.Expanding,
                                       QtGui.QSizePolicy.Expanding)
            self.control.addWidget(main_view)
            self.control.setStretchFactor(0, 2)

            # Create the row editor below the table view
            editor = InstanceEditor(view=factory.edit_view, kind='subpanel')
            self._ui = self.edit_traits(
                parent=self.control,
                kind='subpanel',
                view=View(Item('selected_row',
                               style='custom',
                               editor=editor,
                               show_label=False,
                               resizable=True,
                               width=factory.edit_view_width,
                               height=factory.edit_view_height),
                          resizable=True,
                          handler=factory.edit_view_handler))
            self._ui.parent = self.ui
            self.control.addWidget(self._ui.control)
            self.control.setStretchFactor(1, 1)

        # Connect to the click and double click handlers
        signal = QtCore.SIGNAL('clicked(QModelIndex)')
        QtCore.QObject.connect(self.table_view, signal, self._on_click)
        signal = QtCore.SIGNAL('doubleClicked(QModelIndex)')
        QtCore.QObject.connect(self.table_view, signal, self._on_dclick)

        # Make sure we listen for 'items' changes as well as complete list
        # replacements
        self.context_object.on_trait_change(
            self.update_editor, self.extended_name + '_items', dispatch='ui')

        # Listen for changes to traits on the objects in the list
        self.context_object.on_trait_change(
            self.refresh_editor, self.extended_name + '.-', dispatch='ui')

        # Listen for changes on column definitions
        self.on_trait_change(self._update_columns, 'columns', dispatch='ui')
        self.on_trait_change(self._update_columns, 'columns_items',
                             dispatch='ui')

        # Set up the required externally synchronized traits
        is_list = (mode in ('rows', 'columns', 'cells'))
        self.sync_value(factory.click, 'click', 'to')
        self.sync_value(factory.dclick, 'dclick', 'to')
        self.sync_value(factory.columns_name, 'columns', is_list=True)
        self.sync_value(factory.selected, 'selected', is_list=is_list)
        self.sync_value(
            factory.selected_indices,
            'selected_indices',
            is_list=is_list)
        self.sync_value(factory.filter_name, 'filter', 'from')
        self.sync_value(factory.filtered_indices, 'filtered_indices', 'to')
        self.sync_value(factory.update_filter_name, 'update_filter', 'from')

        self.auto_size = factory.auto_size

        # Initialize the ItemDelegates for each column
        self._update_columns()

    #-------------------------------------------------------------------------
    #  Disposes of the contents of an editor:
    #-------------------------------------------------------------------------

    def dispose(self):
        """ Disposes of the contents of an editor, undoing the listeners
        registered in ``init``."""

        # Make sure that the auxiliary UIs are properly disposed
        if self.toolbar_ui is not None:
            self.toolbar_ui.dispose()
        if self._ui is not None:
            self._ui.dispose()

        # Remove listener for 'items' changes on object trait
        self.context_object.on_trait_change(
            self.update_editor, self.extended_name + '_items', remove=True)

        # Remove listener for changes to traits on the objects in the list
        self.context_object.on_trait_change(
            self.refresh_editor, self.extended_name + '.-', remove=True)

        # Remove listeners for column definition changes
        self.on_trait_change(self._update_columns, 'columns', remove=True)
        self.on_trait_change(
            self._update_columns,
            'columns_items',
            remove=True)

        super(TableEditor, self).dispose()

    #-------------------------------------------------------------------------
    #  Updates the editor when the object trait changes external to the editor:
    #-------------------------------------------------------------------------

    def update_editor(self):
        """Updates the editor when the object trait changes externally to the
        editor. Does nothing while ``_no_notify`` is set (see ``callx``)."""

        if self._no_notify:
            return

        self.table_view.setUpdatesEnabled(False)
        try:
            filtering = len(
                self.factory.filters) > 0 or self.filter is not None
            if filtering:
                self._update_filtering()

            # invalidate the model, but do not reset it. Resetting the model
            # may cause problems if the selection sync'ed traits are being used
            # externally to manage the selections
            self.model.invalidate()

            self.table_view.resizeColumnsToContents()
            if self.auto_size:
                self.table_view.resizeRowsToContents()

        finally:
            self.table_view.setUpdatesEnabled(True)

    def restore_prefs(self, prefs):
        """ Restores any saved user preference information associated with the
            editor (the horizontal header state saved by ``save_prefs``).
        """
        header = self.table_view.horizontalHeader()
        if header is not None and 'column_state' in prefs:
            header.restoreState(prefs['column_state'])

    def save_prefs(self):
        """ Returns any user preference information associated with the editor,
            as a dict with an optional 'column_state' entry.
        """
        prefs = {}
        header = self.table_view.horizontalHeader()
        if header is not None:
            prefs['column_state'] = str(header.saveState())
        return prefs

    #-------------------------------------------------------------------------
    #  Requests that the underlying table widget to redraw itself:
    #-------------------------------------------------------------------------

    def refresh_editor(self):
        """Requests that the underlying table widget redraw itself."""

        self.table_view.viewport().update()

    #-------------------------------------------------------------------------
    #  Creates a new row object using the provided factory:
    #-------------------------------------------------------------------------

    def create_new_row(self):
        """Creates a new row object using the provided factory.

        If the factory keywords contain a '__table_editor__' key, its value is
        replaced by this editor before the row factory is evaluated.
        """

        factory = self.factory
        kw = factory.row_factory_kw.copy()
        if '__table_editor__' in kw:
            kw['__table_editor__'] = self

        return self.ui.evaluate(factory.row_factory,
                                *factory.row_factory_args, **kw)

    #-------------------------------------------------------------------------
    #  Returns the raw list of model objects:
    #-------------------------------------------------------------------------

    def items(self):
        """Returns the raw list of model objects.

        A non-sequence value is wrapped in a one-element list, and the result
        is wrapped in a ``ReversedList`` when ``factory.reverse`` is set.
        Source-model row numbers index into this sequence.
        """

        items = self.value
        if not isinstance(items, SequenceTypes):
            items = [items]

        if self.factory and self.factory.reverse:
            items = ReversedList(items)

        return items

    #-------------------------------------------------------------------------
    #  Perform actions without notifying the underlying table view or model:
    #-------------------------------------------------------------------------

    def callx(self, func, *args, **kw):
        """Call a function without notifying the underlying table view or
        model."""

        old = self._no_notify
        self._no_notify = True
        try:
            func(*args, **kw)
        finally:
            self._no_notify = old

    def setx(self, **keywords):
        """Set one or more attributes without notifying the underlying table
        view or model."""

        old = self._no_notify
        self._no_notify = True
        try:
            for name, value in keywords.items():
                setattr(self, name, value)
        finally:
            self._no_notify = old

    #-------------------------------------------------------------------------
    #  Sets the current selection to a set of specified objects:
    #-------------------------------------------------------------------------

    def set_selection(self, objects=None, notify=True):
        """Sets the current selection to a set of specified objects.

        Parameters
        ----------
        objects : object, sequence or None
            Row objects ('row'/'rows' modes), column names ('column'/'columns'
            modes) or ``(object, column_name)`` pairs ('cell'/'cells' modes).
            A single value is treated as a one-element selection; ``None``
            (the default) clears the selection.
        notify : bool
            When False, selection-model signals are blocked while selecting.
        """

        mode = self.factory.selection_mode

        if objects is None:
            objects = []
        elif mode == 'cell' and isinstance(objects, tuple):
            # In single-cell mode the selection is one (object, column_name)
            # tuple; it must not be iterated as if it were a list of cells.
            objects = [objects]
        elif not isinstance(objects, SequenceTypes):
            objects = [objects]

        indexes = []
        flags = QtGui.QItemSelectionModel.ClearAndSelect

        # In the case of row or column selection, we need a dummy value for the
        # other dimension that has not been filtered.
        source_index = self.model.mapToSource(self.model.index(0, 0))
        source_row, source_column = source_index.row(), source_index.column()

        # Selection mode is 'row' or 'rows'
        if mode.startswith('row'):
            flags |= QtGui.QItemSelectionModel.Rows
            items = self.items()
            for obj in objects:
                try:
                    row = items.index(obj)
                except ValueError:
                    continue
                indexes.append(self.source_model.index(row, source_column))

        # Selection mode is 'column' or 'columns'
        elif mode.startswith('column'):
            flags |= QtGui.QItemSelectionModel.Columns
            for name in objects:
                column = self._column_index_from_name(name)
                if column != -1:
                    indexes.append(self.source_model.index(source_row, column))

        # Selection mode is 'cell' or 'cells'
        else:
            items = self.items()
            for obj, name in objects:
                try:
                    row = items.index(obj)
                except ValueError:
                    continue
                column = self._column_index_from_name(name)
                if column != -1:
                    indexes.append(self.source_model.index(row, column))

        # Perform the selection so that only one signal is emitted
        selection = QtGui.QItemSelection()
        for index in indexes:
            index = self.model.mapFromSource(index)
            if index.isValid():
                self.table_view.setCurrentIndex(index)
                selection.select(index, index)
        smodel = self.table_view.selectionModel()
        try:
            smodel.blockSignals(not notify)
            if len(selection.indexes()):
                smodel.select(selection, flags)
            else:
                smodel.clear()
        finally:
            smodel.blockSignals(False)

    #-------------------------------------------------------------------------
    #  Private methods:
    #-------------------------------------------------------------------------

    def _column_index_from_name(self, name):
        """Returns the index of the column with the given name or -1 if no
        column exists with that name."""

        for i, column in enumerate(self.columns):
            if name == column.name:
                return i
        return -1

    def _customize_filters(self, filter):
        """Allows the user to customize the current set of table filters.

        If the dialog is cancelled, ``filter`` (the previous filter) is
        restored without triggering change handling.
        """

        filter_editor = TableFilterEditor(editor=self)
        ui = filter_editor.edit_traits(parent=self.control)
        if ui.result:
            self.factory.filters = filter_editor.templates
            self.filter = filter_editor.selected_filter
        else:
            self.setx(filter=filter)

    def _update_filtering(self):
        """Update the filter summary and the filtered indices."""

        items = self.items()
        num_items = len(items)

        f = self.filter
        if f is None:
            self._filtered_cache = None
            # Materialize as a list: `filtered_indices` is a List(Int) trait.
            self.filtered_indices = list(range(num_items))
            self.filter_summary = 'All %i items' % num_items
        else:
            if not callable(f):
                f = f.filter
            self._filtered_cache = fc = [f(item) for item in items]
            self.filtered_indices = fi = [i for i, ok in enumerate(fc) if ok]
            self.filter_summary = '%i of %i items' % (len(fi), num_items)

    def _add_image(self, image_resource):
        """ Adds a new image to the image map and returns the created icon.
        """
        image = image_resource.create_icon()

        self.image_resources[image_resource] = image
        self.images[image_resource.name] = image

        return image

    def _get_image(self, image):
        """ Converts a user specified image to a QIcon.
        """
        if isinstance(image, basestring):
            # Let the `image` trait convert the name to an ImageResource.
            self.image = image
            image = self.image

        if isinstance(image, ImageResource):
            result = self.image_resources.get(image)
            if result is not None:
                return result
            return self._add_image(image)

        return self.images.get(image)

    def _object_and_column(self, index):
        """Return the (row object, TableColumn) under a proxy-model index."""

        source_index = self.model.mapToSource(index)
        column = self.columns[source_index.column()]
        obj = self.items()[source_index.row()]
        return obj, column

    #-- Trait Property getters/setters ---------------------------------------

    @cached_property
    def _get_selected_row(self):
        """Gets the selected row, or the first row if multiple rows are
        selected. Returns None in column modes or when nothing is selected."""

        mode = self.factory.selection_mode

        if mode.startswith('column'):
            return None
        elif mode == 'row':
            return self.selected

        try:
            if mode in ('rows', 'cell'):
                # 'rows': list of objects; 'cell': (object, column_name).
                return self.selected[0]
            elif mode == 'cells':
                return self.selected[0][0]
        except IndexError:
            return None

    @cached_property
    def _get_selected_indices(self):
        """Gets the (row, column) source-model indices of the selection."""
        selection_items = self.table_view.selectionModel().selection()
        indices = self.model.mapSelectionToSource(selection_items).indexes()
        return [(index.row(), index.column()) for index in indices]

    def _set_selected_indices(self, indices):
        """Selects the cells at the given (row, column) source-model indices.

        Rows index into ``self.items()``, the same sequence the selection
        handlers use, so the right objects are chosen when
        ``factory.reverse`` is set.
        """
        items = self.items()
        selected = [(items[row], self.columns[col].name)
                    for row, col in indices]

        self.selected = selected
        self.set_selection(self.selected, False)

    #-- Trait Change Handlers ------------------------------------------------

    def _filter_changed(self, old_filter, new_filter):
        """Handles the current filter being changed."""

        if not self._no_notify:
            if new_filter is customize_filter:
                do_later(self._customize_filters, old_filter)
            else:
                self._update_filtering()
                self.model.invalidate()
                self.set_selection(self.selected)

    def _update_columns(self):
        """Handle the column list being changed."""

        self.table_view.setItemDelegate(TableDelegate(self.table_view))
        for i, column in enumerate(self.columns):
            if column.renderer:
                self.table_view.setItemDelegateForColumn(i, column.renderer)

        self.model.reset()
        self.table_view.resizeColumnsToContents()
        if self.auto_size:
            self.table_view.resizeRowsToContents()

    def _selected_changed(self, new):
        """Handle the selected row/column/cell being changed externally."""
        if not self._no_notify:
            self.set_selection(self.selected, notify=False)

    def _update_filter_changed(self):
        """ The filter has changed internally.
        """
        self._filter_changed(self.filter, self.filter)

    #-- Event Handlers -------------------------------------------------------

    def _on_row_selection(self, added, removed):
        """Handle the row selection being changed."""

        items = self.items()
        indexes = self.table_view.selectionModel().selectedRows()
        if len(indexes):
            index = self.model.mapToSource(indexes[0])
            selected = items[index.row()]
        else:
            selected = None

        self.setx(selected=selected)
        self.ui.evaluate(self.factory.on_select, self.selected)

    def _on_rows_selection(self, added, removed):
        """Handle the rows selection being changed."""

        items = self.items()
        indexes = self.table_view.selectionModel().selectedRows()
        selected = [items[self.model.mapToSource(index).row()]
                    for index in indexes]

        self.setx(selected=selected)
        self.ui.evaluate(self.factory.on_select, self.selected)

    def _on_column_selection(self, added, removed):
        """Handle the column selection being changed."""

        indexes = self.table_view.selectionModel().selectedColumns()
        if len(indexes):
            index = self.model.mapToSource(indexes[0])
            selected = self.columns[index.column()].name
        else:
            selected = ''

        self.setx(selected=selected)
        self.ui.evaluate(self.factory.on_select, self.selected)

    def _on_columns_selection(self, added, removed):
        """Handle the columns selection being changed."""

        indexes = self.table_view.selectionModel().selectedColumns()
        selected = [self.columns[self.model.mapToSource(index).column()].name
                    for index in indexes]

        self.setx(selected=selected)
        self.ui.evaluate(self.factory.on_select, self.selected)

    def _on_cell_selection(self, added, removed):
        """Handle the cell selection being changed."""

        items = self.items()
        indexes = self.table_view.selectionModel().selectedIndexes()
        if len(indexes):
            index = self.model.mapToSource(indexes[0])
            obj = items[index.row()]
            column_name = self.columns[index.column()].name
        else:
            obj = None
            column_name = ''
        selected = (obj, column_name)

        self.setx(selected=selected)
        self.ui.evaluate(self.factory.on_select, self.selected)

    def _on_cells_selection(self, added, removed):
        """Handle the cells selection being changed."""

        items = self.items()
        indexes = self.table_view.selectionModel().selectedIndexes()
        selected = []
        for index in indexes:
            index = self.model.mapToSource(index)
            obj = items[index.row()]
            column_name = self.columns[index.column()].name
            selected.append((obj, column_name))

        self.setx(selected=selected)
        self.ui.evaluate(self.factory.on_select, self.selected)

    def _on_click(self, index):
        """Handle a cell being clicked."""

        obj, column = self._object_and_column(index)

        # Fire the same event on the editor after mapping it to a model object
        # and column name:
        self.click = (obj, column)

        # Invoke the column's click handler:
        column.on_click(obj)

    def _on_dclick(self, index):
        """Handle a cell being double clicked."""

        obj, column = self._object_and_column(index)

        # Fire the same event on the editor after mapping it to a model object
        # and column name:
        self.dclick = (obj, column)

        # Invoke the column's double-click handler:
        column.on_dclick(obj)

    def _on_context_insert(self):
        """Handle 'insert item' being selected from the header context menu."""

        self.model.insertRow(self.header_row)

    def _on_context_append(self):
        """Handle 'add item' being selected from the empty space context
        menu."""

        self.model.insertRow(self.model.rowCount())

    def _on_context_remove(self):
        """Handle 'remove item' being selected from the header context menu."""

        self.model.removeRow(self.header_row)

    def _on_context_move_up(self):
        """Handle 'move up' being selected from the header context menu."""

        self.model.moveRow(self.header_row, self.header_row - 1)

    def _on_context_move_down(self):
        """Handle 'move down' being selected from the header context menu."""

        self.model.moveRow(self.header_row, self.header_row + 1)
# Example No. 3 (0)
class InstSource(HasPrivateTraits):
    """Expose measurement information from an inst file.

    Parameters
    ----------
    file : File
        Path to the measurement file (*.fif) to read.

    Attributes
    ----------
    inst : dict | None
        Measurement info returned by ``read_info``. None when no file is set
        or when the file has no digitizer information.
    inst_points : array, shape (n, 3)
        Extra head-shape digitizer points of the file (a single row of zeros
        when no file is set).
    points : array, shape (n, 3)
        ``inst_points`` after applying ``points_filter``.
    nasion, lpa, rpa : array, shape (1, 3)
        Fiducial coordinates. All values are 0 if no file is set or if the
        fiducial is not present in the file.
    """

    file = File(exists=True, filter=['*.fif'])

    inst_fname = Property(Str, depends_on='file')
    inst_dir = Property(depends_on='file')
    inst = Property(depends_on='file')

    points_filter = Any(desc="Index to select a subset of the head shape "
                        "points")
    n_omitted = Property(Int, depends_on=['points_filter'])

    # head shape
    inst_points = Property(depends_on='inst',
                           desc="Head shape points in the "
                           "inst file(n x 3 array)")
    points = Property(depends_on=['inst_points', 'points_filter'],
                      desc="Head "
                      "shape points selected by the filter (n x 3 array)")

    # fiducials
    fid_dig = Property(depends_on='inst',
                       desc="Fiducial points "
                       "(list of dict)")
    fid_points = Property(depends_on='fid_dig',
                          desc="Fiducial points {ident: "
                          "point} dict")
    lpa = Property(depends_on='fid_points',
                   desc="LPA coordinates (1 x 3 "
                   "array)")
    nasion = Property(depends_on='fid_points',
                      desc="Nasion coordinates (1 x "
                      "3 array)")
    rpa = Property(depends_on='fid_points',
                   desc="RPA coordinates (1 x 3 "
                   "array)")

    view = View(
        VGroup(Item('file'),
               Item('inst_fname', show_label=False, style='readonly')))

    @cached_property
    def _get_n_omitted(self):
        # Count of points excluded by the filter. NOTE(review): this assumes
        # points_filter is a boolean mask; an integer index array would be
        # counted incorrectly -- confirm against callers.
        if self.points_filter is None:
            return 0
        else:
            return np.sum(self.points_filter == False)  # noqa: E712

    @cached_property
    def _get_inst(self):
        # Read the measurement info of the selected file. A file without
        # digitizer data is rejected: an error dialog is shown, ``file`` is
        # reset, and None is returned.
        if self.file:
            info = read_info(self.file)
            if info['dig'] is None:
                error(
                    None, "The selected FIFF file does not contain "
                    "digitizer information. Please select a different "
                    "file.", "Error Reading FIFF File")
                self.reset_traits(['file'])
            else:
                return info
        return None

    @cached_property
    def _get_inst_dir(self):
        return os.path.dirname(self.file)

    @cached_property
    def _get_inst_fname(self):
        if self.file:
            return os.path.basename(self.file)
        else:
            return '-'

    @cached_property
    def _get_inst_points(self):
        if not self.inst:
            return np.zeros((1, 3))

        points = np.array([
            d['r'] for d in self.inst['dig']
            if d['kind'] == FIFF.FIFFV_POINT_EXTRA
        ])
        # np.array([]) has shape (0,); keep the documented (n, 3) shape even
        # when the file contains no extra head-shape points.
        return points.reshape(-1, 3)

    @cached_property
    def _get_points(self):
        if self.points_filter is None:
            return self.inst_points
        else:
            return self.inst_points[self.points_filter]

    @cached_property
    def _get_fid_dig(self):  # noqa: D401
        """Fiducials for info['dig']."""
        if not self.inst:
            return []
        return [d for d in self.inst['dig']
                if d['kind'] == FIFF.FIFFV_POINT_CARDINAL]

    @cached_property
    def _get_fid_points(self):
        # Map each cardinal point's identifier to its dig entry.
        if not self.inst:
            return {}
        return {d['ident']: d for d in self.fid_dig}

    def _fid_coords(self, ident):
        """Return the (1, 3) coordinates of fiducial ``ident``, or zeros.

        Zeros are returned when no fiducials are available or when the file
        does not contain this particular fiducial.
        """
        point = self.fid_points.get(ident)
        if point is None:
            return np.zeros((1, 3))
        return point['r'][None, :]

    @cached_property
    def _get_nasion(self):
        return self._fid_coords(FIFF.FIFFV_POINT_NASION)

    @cached_property
    def _get_lpa(self):
        return self._fid_coords(FIFF.FIFFV_POINT_LPA)

    @cached_property
    def _get_rpa(self):
        return self._fid_coords(FIFF.FIFFV_POINT_RPA)

    def _file_changed(self):
        # A filter computed for the previous file does not apply to the new
        # one.
        self.reset_traits(('points_filter', ))
Ejemplo n.º 4
0
class AdapterManager(HasTraits):
    """ A manager for adapter factories. """

    #### 'AdapterManager' interface ###########################################

    # All registered type-scope factories by the type of object that they
    # adapt.
    #
    # The dictionary is keyed by the *name* of the class rather than the class
    # itself to allow for adapter factory proxies to register themselves
    # without having to load and create the factories themselves (i.e., to
    # allow us to lazily load adapter factories contributed by plugins). This
    # is a slight compromise as it is obviously geared towards use in Envisage,
    # but it doesn't affect the API other than allowing a class OR a string to
    # be passed to 'register_adapters'.
    #
    # { String adaptee_class_name : List(AdapterFactory) factories }
    type_factories = Property(Dict)

    # All registered instance-scope factories by the object that they adapt.
    #
    # { id(obj) : List(AdapterFactory) factories }
    instance_factories = Property(Dict)

    # The type system used by the manager (it determines 'is_a' relationships
    # and type MROs etc). By default we use standard Python semantics.
    type_system = Instance(AbstractTypeSystem, PythonTypeSystem())

    #### Private interface ####################################################

    # All registered type-scope factories by the type of object that they
    # adapt.
    _type_factories = Dict

    # All registered instance-scope factories by the object that they adapt.
    _instance_factories = Dict

    ###########################################################################
    # 'AdapterManager' interface.
    ###########################################################################

    #### Properties ###########################################################

    def _get_type_factories(self):
        """ Returns all registered type-scope factories.

        The returned dict is a shallow copy: the per-class factory lists are
        shared with the manager.
        """

        return self._type_factories.copy()

    def _get_instance_factories(self):
        """ Returns all registered instance-scope factories.

        The returned dict is a shallow copy: the per-object factory lists are
        shared with the manager.
        """

        return self._instance_factories.copy()

    #### Methods ##############################################################

    def adapt(self, adaptee, target_class, *args, **kw):
        """ Returns an adapter that adapts an object to the target class.

        'adaptee' is the object that we want to adapt.
        'target_class' is the class that the adaptee should be adapted to.

        Extra positional and keyword arguments are forwarded to the
        factories' 'adapt' method.

        Returns None if no such adapter exists.

        """

        # If the object is already an instance of the target class then we
        # simply return it.
        if self.type_system.is_a(adaptee, target_class):
            return adaptee

        # Look for instance-scope adapters first.
        adapter = self._adapt_instance(adaptee, target_class, *args, **kw)
        if adapter is not None:
            return adapter

        # Otherwise, look at each class in the adaptee's MRO to see if there
        # is a type-scope adapter factory registered that can adapt the
        # object to the target class.
        for adaptee_class in self.type_system.get_mro(type(adaptee)):
            adapter = self._adapt_type(adaptee, adaptee_class,
                                       target_class, *args, **kw)
            if adapter is not None:
                return adapter

        return None

    def register_instance_adapters(self, factory, obj):
        """ Registers an instance-scope adapter factory.

        A factory can be in exactly one manager (as it uses the manager's type
        system).

        """

        # Instances are keyed by id(); see 'unregister_instance_adapters'.
        factories = self._instance_factories.setdefault(id(obj), [])
        factories.append(factory)

        # A factory can be in exactly one manager.
        factory.adapter_manager = self

    def unregister_instance_adapters(self, factory, obj):
        """ Unregisters an instance scope adapter factory.

        A factory can be in exactly one manager (as it uses the manager's type
        system).

        """

        key = id(obj)
        factories = self._instance_factories.get(key)
        if factories is not None:
            if factory in factories:
                factories.remove(factory)

            # Drop the entry once it is empty: keys are object ids, which
            # Python may reuse for a different object after 'obj' is garbage
            # collected, so stale entries must not linger.
            if not factories:
                del self._instance_factories[key]

        # The factory is no longer deemed to be part of this manager.
        factory.adapter_manager = None

    def register_type_adapters(self, factory, adaptee_class):
        """ Registers a type-scope adapter factory.

        'adaptee_class' can be either a class object or the name of a class.

        A factory can be in exactly one manager (as it uses the manager's type
        system).

        """

        if isinstance(adaptee_class, six.string_types):
            adaptee_class_name = adaptee_class

        else:
            adaptee_class_name = self._get_class_name(adaptee_class)

        factories = self._type_factories.setdefault(adaptee_class_name, [])
        factories.append(factory)

        # A factory can be in exactly one manager.
        factory.adapter_manager = self

    def unregister_type_adapters(self, factory):
        """ Unregisters a type-scope adapter factory.

        The factory is removed from every adaptee class it was registered
        for (one occurrence per class).
        """

        for factories in self._type_factories.values():
            if factory in factories:
                factories.remove(factory)

        # The factory is no longer deemed to be part of this manager.
        factory.adapter_manager = None

    #### DEPRECATED ###########################################################

    def register_adapters(self, factory, adaptee_class):
        """ Registers an adapter factory.

        Deprecated: use 'register_type_adapters' instead.

        'adaptee_class' can be either a class object or the name of a class.

        A factory can be in exactly one manager (as it uses the manager's type
        system).

        """

        # stacklevel=2 attributes the warning to the caller, not this method.
        warnings.warn('Use "register_type_adapters" instead.',
                      DeprecationWarning, stacklevel=2)

        self.register_type_adapters(factory, adaptee_class)

    def unregister_adapters(self, factory):
        """ Unregisters an adapter factory.

        Deprecated: use 'unregister_type_adapters' instead.
        """

        # stacklevel=2 attributes the warning to the caller, not this method.
        warnings.warn('use "unregister_type_adapters" instead.',
                      DeprecationWarning, stacklevel=2)

        self.unregister_type_adapters(factory)

    ###########################################################################
    # Private interface.
    ###########################################################################

    def _adapt_instance(self, adaptee, target_class, *args, **kw):
        """ Returns an adapter that adapts an object to the target class.

        Only instance-scope factories registered for 'adaptee' are tried, in
        registration order; the first non-None adapter wins.

        Returns None if no such adapter exists.

        """

        for factory in self._instance_factories.get(id(adaptee), []):
            adapter = factory.adapt(adaptee, target_class, *args, **kw)
            if adapter is not None:
                return adapter

        return None

    def _adapt_type(self, adaptee, adaptee_class, target_class, *args, **kw):
        """ Returns an adapter that adapts an object to the target class.

        Only type-scope factories registered for 'adaptee_class' are tried,
        in registration order; the first non-None adapter wins.

        Returns None if no such adapter exists.

        """

        adaptee_class_name = self._get_class_name(adaptee_class)

        for factory in self._type_factories.get(adaptee_class_name, []):
            adapter = factory.adapt(adaptee, target_class, *args, **kw)
            if adapter is not None:
                return adapter

        return None

    def _get_class_name(self, klass):
        """ Returns the full class name ('module.ClassName') for a class. """

        return "%s.%s" % (klass.__module__, klass.__name__)
Ejemplo n.º 5
0
class Vis2D(HasStrictTraits):
    '''Base class for objects that can be visualized by 2D plot objects.

    Each state and operator object can be associated with several
    visualization objects, referred to by the shortened class name Viz2D.
    To introduce an independent class subsystem into the class structure,
    objects supporting visualization inherit from this class. It keeps the
    list of attached ``viz2d`` objects and the dictionary of applicable
    ``viz2d_classes``.
    '''
    def setup(self):
        pass

    sim = WeakRef
    '''Root of the simulation to extract the data
    '''

    vot = Float(0.0, time_change=True)
    '''Visual object time
    '''

    viz2d_classes = Dict
    '''Visualization classes applicable to this object. 
    '''

    viz2d_class_names = Property(List(Str), depends_on='viz2d_classes')
    '''Keys of the viz2d classes
    '''

    @cached_property
    def _get_viz2d_class_names(self):
        return list(self.viz2d_classes.keys())

    selected_viz2d_class = Str

    def _selected_viz2d_class_default(self):
        # Preselect the first available class, if any.
        if len(self.viz2d_class_names) > 0:
            return self.viz2d_class_names[0]
        else:
            return ''

    add_selected_viz2d = Button(label='Add plot viz2d')

    def _add_selected_viz2d_fired(self):
        viz2d_class_name = self.selected_viz2d_class
        self.add_viz2d(viz2d_class_name, '<unnamed>')

    def add_viz2d(self, class_name, name, **kw):
        '''Create a viz2d of class ``viz2d_classes[class_name]`` and attach it.

        An empty ``name`` defaults to ``class_name``. Extra keyword
        arguments are passed to the viz2d constructor. If this object has a
        truthy ``ui`` attribute, the new viz2d is also added to
        ``ui.viz_sheet.viz2d_list``.

        Raises KeyError if ``class_name`` is not in ``viz2d_classes``.
        '''
        if name == '':
            name = class_name
        viz2d_class = self.viz2d_classes[class_name]
        viz2d = viz2d_class(name=name, vis2d=self, **kw)
        self.viz2d.append(viz2d)
        # 'ui' is not a trait of this class; it may be provided by subclasses.
        ui = getattr(self, 'ui', None)
        if ui:
            ui.viz_sheet.viz2d_list.append(viz2d)

    viz2d = List(Viz2D)

    actions = HGroup(
        UItem('add_selected_viz2d'),
        UItem('selected_viz2d_class',
              springy=True,
              editor=EnumEditor(name='object.viz2d_class_names', )),
    )

    def plt(self, name, label=None):
        '''Return a Viz2DPlot using ``name`` as its plot function.'''
        return Viz2DPlot(plot_fn=name, label=label, vis2d=self)

    view = View(Include('actions'), resizable=True)
Ejemplo n.º 6
0
class StatisticViewHandlerMixin(HasTraits):
    """Mixin exposing the index names and levels of the selected statistic.

    The host handler is expected to provide ``info.ui.context['context']``
    (with a ``statistics`` mapping) and a ``model`` with ``statistic`` and
    ``subset`` traits.
    """

    numeric_indices = Property(depends_on="model.statistic, model.subset")
    indices = Property(depends_on="model.statistic, model.subset")
    levels = Property(depends_on="model.statistic")

    def _current_statistic(self):
        """Return the selected statistic, or None if none is selected."""
        context = self.info.ui.context['context']

        if not (context and context.statistics and self.model
                and self.model.statistic[0]):
            return None

        return context.statistics[self.model.statistic]

    @staticmethod
    def _drop_constant_levels(index):
        """Return ``index`` without the levels that hold a single value."""
        for name in list(index.names):
            if len(index.get_level_values(name).unique()) == 1:
                index = index.droplevel(name)
        return index

    def _subset_index_frame(self):
        """Return an empty frame indexed like the selected statistic.

        The frame is filtered by ``model.subset`` and has its constant index
        levels dropped. Returns None when no statistic is selected or when
        the subset leaves no rows.
        """
        stat = self._current_statistic()
        if stat is None:
            return None

        data = pd.DataFrame(index=stat.index)

        if self.model.subset:
            data = data.query(self.model.subset)

        if len(data) == 0:
            return None

        data.index = self._drop_constant_levels(data.index)
        return data

    # MAGIC: gets the value for the property numeric_indices
    def _get_numeric_indices(self):
        data = self._subset_index_frame()
        if data is None:
            return []

        # Turn the index levels into columns so each can be type-checked.
        data.reset_index(inplace=True)
        return [x for x in data if util.is_numeric(data[x])]

    # MAGIC: gets the value for the property indices
    def _get_indices(self):
        data = self._subset_index_frame()
        if data is None:
            return []

        return list(data.index.names)

    # MAGIC: gets the value for the property 'levels'
    # returns a Dict(Str, pd.Series) mapping each non-constant index level
    # name to its sorted unique values.
    def _get_levels(self):
        stat = self._current_statistic()
        if stat is None:
            # Empty dict (not a list) so the return type is always a dict.
            return {}

        index = self._drop_constant_levels(stat.index)

        ret = {}
        for name in index.names:
            values = pd.Series(index.get_level_values(name)).sort_values()
            ret[name] = pd.Series(values.unique())

        return ret
class MassBalanceAnalyzer(HasStrictTraits):
    """ Tool to compare loaded mass from continuous data to the method/solution
    information.
    """
    #: Target information we are evaluating mass balance for
    target_experiment = Instance(Experiment)

    #: Time before which to truncate experiment data
    time_of_origin = Instance(UnitScalar)

    #: Continuous UV data to evaluate loaded mass from (passed separately to
    #: support doing analysis while building an exp)
    continuous_data = Instance(XYData)

    #: Loaded mass, as computed from method information
    mass_from_method = Instance(UnitScalar)

    #: Loaded mass, as computed from UV continuous data
    mass_from_uv = Instance(UnitScalar)

    #: Current load product concentration in target experiment
    current_concentration = Property(Instance(UnitScalar))

    #: Current load step volume in target experiment
    current_volume = Property(Instance(UnitScalar))

    #: Threshold for relative difference to define the data as imbalanced
    balance_threshold = Float

    #: Whether the target experiment's data is balanced (in agreement)
    balanced = Bool

    # View elements -----------------------------------------------------------

    #: Plot of the UV data
    plot = Instance(Plot)

    #: Data container for the plot
    plot_data = Instance(ArrayPlotData)

    # Traits methods ----------------------------------------------------------

    def __init__(self, **traits):
        if "target_experiment" not in traits:
            msg = "A target experiment is required to create a {}"
            msg = msg.format(self.__class__.__name__)
            # Not inside an except block: logger.exception would log a bogus
            # "NoneType: None" traceback.
            logger.error(msg)
            raise ValueError(msg)

        super(MassBalanceAnalyzer, self).__init__(**traits)
        use_exp_data = (self.continuous_data is None
                        and self.target_experiment.output is not None)

        # If the tool is created for a fully created experiment, extract the
        # continuous data from the experiment,
        if use_exp_data:
            data_dict = self.target_experiment.output.continuous_data
            self.continuous_data = data_dict[UV_DATA_KEY]

    def traits_view(self):
        view = KromView(
            Item("plot", editor=ComponentEditor(), show_label=False))
        return view

    # Public interface --------------------------------------------------------

    def analyze(self):
        """ Compare the loaded mass from method data and from UV data.

        Sets ``mass_from_method``, ``mass_from_uv`` and ``balanced``. When no
        UV data is available, the experiment is considered balanced.

        Returns
        -------
        None

        Raises
        ------
        ValueError
            If the product has more than one component.
        """
        product = self.target_experiment.product
        if len(product.product_components) > 1:
            msg = "Unable to analyze the loaded mass for multi-component " \
                  "products, since we don't have a product extinction coeff."
            logger.error(msg)
            raise ValueError(msg)

        msg = "Analyzing mass balance for {}".format(
            self.target_experiment.name)
        logger.debug(msg)

        self.mass_from_method = self.loaded_mass_from_method()
        self.mass_from_uv = self.loaded_mass_from_uv()

        if self.mass_from_uv is None:
            self.balanced = True
            return

        diff = abs(self.mass_from_method - self.mass_from_uv)
        rel_diff = diff / self.mass_from_method
        msg = "Loaded mass computed from method and UV are different by " \
              "{:.3g}%".format(float(rel_diff)*100)
        logger.debug(msg)

        self.balanced = float(rel_diff) <= self.balance_threshold

    def loaded_mass_from_method(self):
        """ Returns loaded product mass in grams, as computed from method data.

        Raises
        ------
        NotImplementedError
            If the load step volume is neither in CV nor in volume units.
        """
        vol = self.current_volume
        if units_almost_equal(vol, column_volumes):
            # Convert column volumes into an absolute volume.
            vol = float(vol) * self.target_experiment.column.volume
        elif has_volume_units(vol):
            pass
        else:
            msg = "Unexpected unit for the load step volume: {}"
            msg = msg.format(vol.units.label)
            logger.error(msg)
            raise NotImplementedError(msg)

        mass = vol * self.current_concentration
        return convert_units(mass, tgt_unit="gram")

    def loaded_mass_from_uv(self):
        """ Returns loaded product mass in grams, as computed from UV data.

        Returns None when no continuous data is available. Integration runs
        from ``time_of_origin`` to the last method step boundary time.
        """
        if not self.continuous_data:
            return None

        data = self.continuous_data
        product = self.target_experiment.product
        ext_coeff = product.product_components[0].extinction_coefficient
        method_step_times = self.target_experiment.method_step_boundary_times
        t_stop = UnitScalar(method_step_times[-1],
                            units=method_step_times.units)
        mass = compute_mass_from_abs_data(data,
                                          ext_coeff=ext_coeff,
                                          experim=self.target_experiment,
                                          t_start=self.time_of_origin,
                                          t_stop=t_stop)
        return convert_units(mass, tgt_unit="gram")

    # Adjustment methods ------------------------------------------------------

    def compute_loaded_vol(self, tgt_units=column_volumes):
        """ Compute the load step volume that would match the UV data at
        constant load solution concentration.

        Parameters
        ----------
        tgt_units : unit
            Either ``column_volumes`` (default) or ``g_per_liter_resin``.

        Returns
        -------
        UnitScalar
            Load step volume, in CV, that would be needed to match the UV
            data; or, when ``tgt_units`` is ``g_per_liter_resin``, the
            corresponding loaded mass per liter of resin.

        Raises
        ------
        ValueError
            If ``tgt_units`` is not one of the supported units.
        """
        if tgt_units not in [column_volumes, g_per_liter_resin]:
            msg = "Supported target units are CV and g/Liter of resin but " \
                  "got {}.".format(tgt_units.label)
            logger.error(msg)
            raise ValueError(msg)

        target_mass = self.mass_from_uv
        col_volume = self.target_experiment.column.volume
        concentration = self.current_concentration
        vol = target_mass / concentration / col_volume
        # Test equality on the labels since CV and g_per_liter_resin are equal
        # from a derivation point of view (dimensionless)
        if tgt_units.label == g_per_liter_resin.label:
            vol = float(vol * concentration)
            vol = UnitScalar(vol, units=g_per_liter_resin)
        else:
            vol = UnitScalar(vol, units=column_volumes)

        return vol

    def compute_concentration(self):
        """ Compute the load solution concentration that would match the UV
        data at constant load step volume.

        Returns
        -------
        UnitScalar
            Load solution concentration, in g/L, that would be needed to match
            the UV data.
        """
        target_mass = self.mass_from_uv
        vol = self.current_volume
        if units_almost_equal(vol, column_volumes):
            vol = float(vol) * self.target_experiment.column.volume

        concentration = target_mass / vol
        return convert_units(concentration, tgt_unit="g/L")

    # Traits property getters/setters -----------------------------------------

    def _get_current_volume(self):
        load_step = self.target_experiment.method.load
        return load_step.volume

    def _get_current_concentration(self):
        # Only the first solution and first product component are used.
        load_sol = self.target_experiment.method.load.solutions[0]
        comp_concs = load_sol.product_component_concentrations
        return unitarray_to_unitted_list(comp_concs)[0]

    # Traits listeners --------------------------------------------------------

    def _continuous_data_changed(self):
        # Nothing to update before the plot exists, or once the data is
        # cleared (accessing x_data on None would raise).
        if self.plot_data is None or self.continuous_data is None:
            return

        self.plot_data.update_data(x=self.continuous_data.x_data,
                                   y=self.continuous_data.y_data)

    # Traits initialization methods -------------------------------------------

    def _balance_threshold_default(self):
        from kromatography.utils.app_utils import get_preferences
        prefs = get_preferences()
        return prefs.file_preferences.exp_importer_mass_threshold

    def _plot_default(self):
        self.plot_data = ArrayPlotData(x=self.continuous_data.x_data,
                                       y=self.continuous_data.y_data)
        plot = Plot(self.plot_data)
        plot.plot(("x", "y"))
        x_units = self.continuous_data.x_metadata["units"]
        y_units = self.continuous_data.y_metadata["units"]
        plot.index_axis.title = "Time ({})".format(x_units)
        plot.value_axis.title = "UV Absorption ({})".format(y_units)
        # Add zoom and pan tools to the plot
        zoom = BetterSelectingZoom(component=plot,
                                   tool_mode="box",
                                   always_on=False)
        plot.overlays.append(zoom)
        plot.tools.append(PanTool(component=plot))
        return plot
Ejemplo n.º 8
0
class VUMeter(Component):
    """ A VU-style meter: a curved axis with dB/percent ticks and a needle.

    The needle position is driven by ``percent`` (or, equivalently, ``db``)
    and is clamped to ``max_percent`` when drawn.
    """

    # Value expressed in dB
    db = Property(Float)

    # Value expressed as a percent.
    percent = Range(low=0.0)

    # The maximum value to be displayed in the VU Meter, expressed as a
    # percent.
    max_percent = Float(150.0)

    # Angle (in degrees) from a horizontal line through the hinge of the
    # needle to the edge of the meter axis.
    angle = Float(45.0)

    # Values of the percentage-based ticks; these are drawn and labeled along
    # the bottom of the curve axis.
    percent_ticks = List(list(range(0, 101, 20)))

    # Text to write in the middle of the VU Meter.
    text = Str("VU")

    # Font used to draw `text`.
    text_font = KivaFont("modern 48")

    # Font for the db tick labels.
    db_tick_font = KivaFont("modern 16")

    # Font for the percent tick labels.
    percent_tick_font = KivaFont("modern 12")

    # beta is the fraction of the needle that is "hidden".
    # beta == 0 puts the hinge point of the needle on the bottom
    # edge of the window.  Values that result in a decent looking
    # meter are 0 < beta < .65.
    # XXX needs a better name!
    _beta = Float(0.3)

    # _outer_radial_margin is the radial extent beyond the circular axis
    # to include in calculations of the space required for the meter.
    # This allows room for the ticks and labels.
    _outer_radial_margin = Float(60.0)

    # The angle (in radians) of the span of the curve axis.
    _phi = Property(Float, depends_on=['angle'])

    # This is the radius of the circular axis (in screen coordinates).
    _axis_radius = Property(Float, depends_on=['_phi', 'width', 'height'])

    #---------------------------------------------------------------------
    # Trait Property methods
    #---------------------------------------------------------------------

    def _get_db(self):
        db = percent_to_db(self.percent)
        return db

    def _set_db(self, value):
        self.percent = db_to_percent(value)

    def _get__phi(self):
        phi = math.pi * (180.0 - 2 * self.angle) / 180.0
        return phi

    def _get__axis_radius(self):
        # The largest radius such that the axis (plus margin) fits both the
        # width (R1) and the height (R2) of the component.
        M = self._outer_radial_margin
        beta = self._beta
        w = self.width
        h = self.height
        phi = self._phi

        R1 = w / (2 * math.sin(phi / 2)) - M
        R2 = (h - M) / (1 - beta * math.cos(phi / 2))
        R = min(R1, R2)
        return R

    #---------------------------------------------------------------------
    # Trait change handlers
    #---------------------------------------------------------------------

    def _anytrait_changed(self):
        self.request_redraw()

    #---------------------------------------------------------------------
    # Component API
    #---------------------------------------------------------------------

    def _draw_mainlayer(self, gc, view_bounds=None, mode="default"):
        """ Draw the ticks, labels, curved axes, text and needle. """

        beta = self._beta
        phi = self._phi

        M = self._outer_radial_margin
        R = self._axis_radius

        # (ox, oy) is the position of the "hinge point" of the needle
        # (i.e. the center of rotation).  For beta > ~0, oy lies below the
        # bottom edge of the component, outside the visible region.  Both
        # coordinates are offset by the component's position so the meter
        # is drawn where the component sits in its container.
        ox = self.x + self.width // 2
        oy = self.y - beta * R * math.cos(phi / 2) + 1

        left_theta = math.radians(180 - self.angle)
        right_theta = math.radians(self.angle)

        # The angle of the 100% position.
        nominal_theta = self._percent_to_theta(100.0)

        # The color of the axis for percent > 100.
        red = (0.8, 0, 0)

        with gc:
            gc.set_antialias(True)

            # Draw everything relative to the center of the circles.
            gc.translate_ctm(ox, oy)

            # Draw the primary ticks and tick labels on the curved axis.
            gc.set_fill_color((0, 0, 0))
            gc.set_font(self.db_tick_font)
            for db in [-20, -10, -7, -5, -3, -2, -1, 0, 1, 2, 3]:
                db_percent = db_to_percent(db)
                theta = self._percent_to_theta(db_percent)
                x1 = R * math.cos(theta)
                y1 = R * math.sin(theta)
                x2 = (R + 0.3 * M) * math.cos(theta)
                y2 = (R + 0.3 * M) * math.sin(theta)
                gc.set_line_width(2.5)
                gc.move_to(x1, y1)
                gc.line_to(x2, y2)
                gc.stroke_path()

                text = str(db)
                if db > 0:
                    text = '+' + text
                self._draw_rotated_label(gc, text, theta, R + 0.4 * M)

            # Draw the secondary ticks on the curve axis.
            for db in [-15, -9, -8, -6, -4, -0.5, 0.5]:
                db_percent = db_to_percent(db)
                theta = self._percent_to_theta(db_percent)
                x1 = R * math.cos(theta)
                y1 = R * math.sin(theta)
                x2 = (R + 0.2 * M) * math.cos(theta)
                y2 = (R + 0.2 * M) * math.sin(theta)
                gc.set_line_width(1.0)
                gc.move_to(x1, y1)
                gc.line_to(x2, y2)
                gc.stroke_path()

            # Draw the percent ticks and label on the bottom of the
            # curved axis.
            gc.set_font(self.percent_tick_font)
            gc.set_fill_color((0.5, 0.5, 0.5))
            gc.set_stroke_color((0.5, 0.5, 0.5))
            percents = self.percent_ticks
            for tick_percent in percents:
                theta = self._percent_to_theta(tick_percent)
                x1 = (R - 0.15 * M) * math.cos(theta)
                y1 = (R - 0.15 * M) * math.sin(theta)
                x2 = R * math.cos(theta)
                y2 = R * math.sin(theta)
                gc.set_line_width(2.0)
                gc.move_to(x1, y1)
                gc.line_to(x2, y2)
                gc.stroke_path()

                text = str(tick_percent)
                # Only the last tick label carries the unit.
                if tick_percent == percents[-1]:
                    text = text + "%"
                self._draw_rotated_label(gc, text, theta, R - 0.3 * M)

            if self.text:
                gc.set_font(self.text_font)
                _, _, tw, _ = gc.get_text_extent(self.text)
                gc.set_fill_color((0, 0, 0, 0.25))
                gc.set_text_matrix(affine.affine_from_rotation(0))
                gc.set_text_position(-0.5 * tw, (0.75 * beta + 0.25) * R)
                gc.show_text(self.text)

            # Draw the red curved axis.
            gc.set_stroke_color(red)
            red_width = 10
            gc.set_line_width(red_width)
            gc.arc(0, 0, R + 0.5 * red_width - 1, right_theta, nominal_theta)
            gc.stroke_path()

            # Draw the black curved axis.
            black_width = 4
            gc.set_line_width(black_width)
            gc.set_stroke_color((0, 0, 0))
            gc.arc(0, 0, R + 0.5 * black_width - 1, nominal_theta, left_theta)
            gc.stroke_path()

            # Draw the filled arc at the bottom.
            gc.set_line_width(2)
            gc.set_stroke_color((0, 0, 0))
            gc.arc(0, 0, beta * R, math.radians(self.angle),
                   math.radians(180 - self.angle))
            gc.stroke_path()
            gc.set_fill_color((0, 0, 0, 0.25))
            gc.arc(0, 0, beta * R, math.radians(self.angle),
                   math.radians(180 - self.angle))
            gc.fill_path()

            # Draw the needle.
            # If percent exceeds max_percent, the needle is drawn at max_percent.
            percent = min(self.percent, self.max_percent)
            needle_theta = self._percent_to_theta(percent)
            gc.rotate_ctm(needle_theta - 0.5 * math.pi)
            self._draw_vertical_needle(gc)

    #---------------------------------------------------------------------
    # Private methods
    #---------------------------------------------------------------------

    def _draw_vertical_needle(self, gc):
        """ Draw the needle of the meter, pointing straight up. """
        beta = self._beta
        R = self._axis_radius
        end_y = beta * R
        blob_y = R - 0.6 * self._outer_radial_margin
        tip_y = R + 0.2 * self._outer_radial_margin
        lw = 5

        with gc:
            gc.set_alpha(1)
            gc.set_fill_color((0, 0, 0))

            # Draw the needle from the bottom to the blob.
            gc.set_line_width(lw)
            gc.move_to(0, end_y)
            gc.line_to(0, blob_y)
            gc.stroke_path()

            # Draw the thin part of the needle from the blob to the tip.
            gc.move_to(lw, blob_y)
            control_y = blob_y + 0.25 * (tip_y - blob_y)
            gc.quad_curve_to(0.2 * lw, control_y, 0, tip_y)
            gc.quad_curve_to(-0.2 * lw, control_y, -lw, blob_y)
            gc.line_to(lw, blob_y)
            gc.fill_path()

            # Draw the blob on the needle.
            gc.arc(0, blob_y, 6.0, 0, 2 * math.pi)
            gc.fill_path()

    def _draw_rotated_label(self, gc, text, theta, radius):
        """ Draw ``text`` centered on angle ``theta`` at ``radius``, rotated
        so that it is tangent to the circle. """

        _, _, tw, _ = gc.get_text_extent(text)

        # Shift the start position so the label is centered on theta.
        rr = math.sqrt(radius**2 + (0.5 * tw)**2)
        dtheta = math.atan2(0.5 * tw, radius)
        text_theta = theta + dtheta
        x = rr * math.cos(text_theta)
        y = rr * math.sin(text_theta)

        rot_theta = theta - 0.5 * math.pi
        with gc:
            gc.set_text_matrix(affine.affine_from_rotation(rot_theta))
            gc.set_text_position(x, y)
            gc.show_text(text)

    def _percent_to_theta(self, percent):
        """ Convert percent to the angle theta, in radians.

        theta is the angle of the needle measured counterclockwise from
        the horizontal (i.e. the traditional angle of polar coordinates).
        0% maps to 180 - angle degrees and max_percent maps to angle degrees.
        """
        angle = (self.angle + (180.0 - 2 * self.angle) *
                 (self.max_percent - percent) / self.max_percent)
        theta = math.radians(angle)
        return theta

    def _db_to_theta(self, db):
        """ Convert db to the angle theta, in radians. """
        percent = db_to_percent(db)
        theta = self._percent_to_theta(percent)
        return theta
Ejemplo n.º 9
0
class ImageActor(Module):
    """ Module that shows image data through a `tvtk.ImageActor`.

    The source data is either fed to the actor directly or, when
    `map_scalars_to_color` is set, first passed through an
    `ImageMapToColors` filter that uses the module manager's scalar lookup
    table.  Color mapping is switched on automatically when the source's
    point scalars are not a `tvtk.UnsignedCharArray`.
    """

    # An image actor.
    actor = Instance(tvtk.ImageActor, allow_none=False, record=True)

    input_info = PipelineInfo(datasets=['image_data'],
                              attribute_types=['any'],
                              attributes=['any'])

    # An ImageMapToColors TVTK filter to adapt datasets without color
    # information
    image_map_to_color = Instance(tvtk.ImageMapToColors, (),
                                  allow_none=False,
                                  record=True)

    # Whether the source data is routed through `image_map_to_color` before
    # reaching the actor.
    map_scalars_to_color = Bool

    # True when the source's point scalars are not unsigned chars; in that
    # case `map_scalars_to_color` is forced on by `update_pipeline`.
    _force_map_scalars_to_color = Property(depends_on='module_manager.source')

    ########################################
    # The view of this module.

    view = View(Group(Item(name='actor', style='custom', resizable=True),
                      show_labels=False,
                      label='Actor'),
                Group(
                    Group(
                        Item('map_scalars_to_color',
                             enabled_when='not _force_map_scalars_to_color')),
                    Item('image_map_to_color',
                         style='custom',
                         enabled_when='map_scalars_to_color',
                         show_label=False),
                    label='Map Scalars',
                ),
                width=500,
                height=600,
                resizable=True)

    ######################################################################
    # `Module` interface
    ######################################################################
    def setup_pipeline(self):
        """Create the image actor used to display the data."""
        self.actor = tvtk.ImageActor()

    @on_trait_change('map_scalars_to_color,'
                     'image_map_to_color.[output_format,pass_alpha_to_output],'
                     'module_manager.scalar_lut_manager.lut_mode,'
                     'module_manager.vector_lut_manager.lut_mode')
    def update_pipeline(self):
        """Override this method so that it *updates* the tvtk pipeline
        when data upstream is known to have changed.

        Connects the module manager's source to the actor, optionally via
        `image_map_to_color` (configured with the scalar lookup table), and
        sets `pipeline_changed`.  It is also re-run whenever the
        color-mapping settings or the lookup-table modes listed in the
        decorator change.
        """
        mm = self.module_manager
        if mm is None:
            return
        src = mm.source
        if self._force_map_scalars_to_color:
            # Set silently: this method listens to `map_scalars_to_color`,
            # so a notification would re-enter it.  `trait_set` replaces the
            # deprecated `HasTraits.set` alias.
            self.trait_set(map_scalars_to_color=True,
                           trait_change_notify=False)
        if self.map_scalars_to_color:
            self.configure_connection(self.image_map_to_color, src)
            self.image_map_to_color.lookup_table = mm.scalar_lut_manager.lut
            self.image_map_to_color.update()
            self.configure_input_data(self.actor,
                                      self.image_map_to_color.output)
        else:
            self.configure_input_data(self.actor, src.outputs[0])
        self.pipeline_changed = True

    def update_data(self):
        """Override this method so that it flushes the vtk pipeline if
        that is necessary.
        """
        # Just set data_changed, the component should do the rest.
        self.data_changed = True

    ######################################################################
    # Non-public methods.
    ######################################################################
    def _actor_changed(self, old, new):
        # Keep `actors` in sync with the current actor and re-render
        # whenever any of the actor's traits change.
        if old is not None:
            self.actors.remove(old)
            old.on_trait_change(self.render, remove=True)
        self.actors.append(new)
        new.on_trait_change(self.render)

    def _get__force_map_scalars_to_color(self):
        """Return True unless the source's first output has point scalars
        of type `tvtk.UnsignedCharArray`.

        Returns False when there is no module manager, no source, or the
        source has no outputs yet (instead of raising).
        """
        mm = self.module_manager
        if mm is None:
            return False
        src = mm.source
        if src is None or not src.outputs:
            return False
        return not isinstance(src.outputs[0].point_data.scalars,
                              tvtk.UnsignedCharArray)
class MATSXDMicroplaneDamageFatigueJir(MATSEvalMicroplaneFatigue):

    '''
    Microplane Damage Fatigue Model.

    Projects the apparent strain tensor onto the microplanes, evaluates the
    normal and tangential microplane laws (`get_normal_Law`,
    `get_tangential_Law`) and homogenizes the resulting damage and
    inelastic strains into a damaged secant stiffness tensor and a stress
    tensor (`get_corr_pred`).

    The microplane normals `_MPN`, weights `_MPW`, the number of
    microplanes `n_mp` and `elasticity_tensors` are defined by concrete
    subclasses such as `MATS3DMicroplaneDamageJir`.
    '''

    #-------------------------------------------------------------------------
    # Setup for computation within a supplied spatial context
    #-------------------------------------------------------------------------
    # Elasticity tensor of the undamaged material (`elasticity_tensors`).
    D4_e = Property

    def _get_D4_e(self):
        # Return the elasticity tensor
        return self.elasticity_tensors

    #-------------------------------------------------------------------------
    # MICROPLANE-Kinematic constraints
    #-------------------------------------------------------------------------

    # Dyadic product n_i n_j of the normal of each microplane.
    _MPNN = Property(depends_on='n_mp')

    @cached_property
    def _get__MPNN(self):
        # dyadic product of the microplane normals
        MPNN_nij = einsum('ni,nj->nij', self._MPN, self._MPN)
        return MPNN_nij

    # Third order tangential tensor (operator) for each microplane.
    _MPTT = Property(depends_on='n_mp')

    @cached_property
    def _get__MPTT(self):
        # T_nijr = 0.5 * (n_i delta_jr + n_j delta_ir) - n_i n_j n_r
        #
        # Both Kronecker terms must use the output subscripts 'nijr':
        # labelling the second one 'njir' would make it an exact copy of the
        # first term instead of its (i, j)-transpose.
        delta = identity(3)
        MPTT_nijr = 0.5 * (einsum('ni,jr -> nijr', self._MPN, delta) +
                           einsum('nj,ir -> nijr', self._MPN, delta) - 2 *
                           einsum('ni,nj,nr -> nijr', self._MPN, self._MPN, self._MPN))
        return MPTT_nijr

    def _get_e_vct_arr(self, eps_eng):
        # Projection of apparent strain onto the individual microplanes:
        # e_ni = n_nj * eps_ji
        e_ni = einsum('nj,ji->ni', self._MPN, eps_eng)
        return e_ni

    def _get_e_N_arr(self, e_vct_arr):
        # Normal strain of each microplane: e_N_n = e_ni * n_ni
        eN_n = einsum('ni,ni->n', e_vct_arr, self._MPN)
        return eN_n

    def _get_e_T_vct_arr(self, e_vct_arr):
        # Tangential strain vector of each microplane: the microplane strain
        # vector minus its normal component e_N_n * n_ni
        eN_n = self._get_e_N_arr(e_vct_arr)
        eN_vct_ni = einsum('n,ni->ni', eN_n, self._MPN)
        return e_vct_arr - eN_vct_ni

    #-------------------------------------------------
    # Alternative methods for the kinematic constraint
    #-------------------------------------------------
    def _get_e_N_arr_2(self, eps_eng):
        # Normal strains directly from the strain tensor:
        # e_N_n = (n_i n_j)_n * eps_ij
        return einsum('nij,ij->n', self._MPNN, eps_eng)

    def _get_e_T_vct_arr_2(self, eps_eng):
        # Tangential strain vectors directly from the strain tensor:
        # e_T_nr = T_nijr * eps_ij, with T the cached operator `_MPTT`.
        return einsum('nijr,ij->nr', self._MPTT, eps_eng)

    def _get_e_vct_arr_2(self, eps_eng):
        # Microplane strain vectors assembled from their normal and
        # tangential parts: e_ni = e_N_n * n_ni + e_T_ni.
        # The parts come from the two methods above (there are no
        # `_e_N_arr_2` / `_e_t_vct_arr_2` attributes), and the per-plane
        # normal strains are broadcast onto the normals explicitly.
        e_N_n = self._get_e_N_arr_2(eps_eng)
        e_T_ni = self._get_e_T_vct_arr_2(eps_eng)
        return einsum('n,ni->ni', e_N_n, self._MPN) + e_T_ni

    #--------------------------------------------------------
    # return the state variables (Damage , inelastic strains)
    #--------------------------------------------------------
    def _get_state_variables(self, sctx, eps_app_eng, sigma_kk):
        '''
        Evaluate the microplane laws and return the updated state array.

        The result has the shape of `sctx`; columns 0:5 hold the output of
        `get_normal_Law` and columns 5:14 the output of
        `get_tangential_Law`.  Within this class, column 0 is read as the
        normal damage, 4 as the plastic normal strain, 5 as the tangential
        damage and 10:13 as the sliding strain vector.
        '''
        e_N_arr = self._get_e_N_arr_2(eps_app_eng)
        e_T_vct_arr = self._get_e_T_vct_arr_2(eps_app_eng)

        sctx_arr = zeros_like(sctx)

        sctx_N = self.get_normal_Law(e_N_arr, sctx)
        sctx_arr[:, 0:5] = sctx_N

        sctx_tangential = self.get_tangential_Law(e_T_vct_arr, sctx, sigma_kk)
        sctx_arr[:, 5:14] = sctx_tangential

        return sctx_arr

    #-----------------------------------------------------------------
    # Returns a list of the plastic normal strain  for all microplanes.
    #-----------------------------------------------------------------
    def _get_eps_N_p_arr(self, sctx, eps_app_eng, sigma_kk):

        eps_N_p = self._get_state_variables(sctx, eps_app_eng, sigma_kk)[:, 4]
        return eps_N_p

    #----------------------------------------------------------------
    # Returns a list of the sliding strain vector for all microplanes.
    #----------------------------------------------------------------
    def _get_eps_T_pi_arr(self, sctx, eps_app_eng, sigma_kk):

        eps_T_pi_vct_arr = self._get_state_variables(
            sctx, eps_app_eng, sigma_kk)[:, 10:13]

        return eps_T_pi_vct_arr

    #---------------------------------------------------------------------
    # Extra homogenization of damage tensor in case of two damage parameters
    #---------------------------------------------------------------------

    def _get_beta_N_arr(self, sctx, eps_app_eng, sigma_kk):

        # Returns a list of the normal integrity factors for all microplanes.

        beta_N_arr = sqrt(1 -
                          self._get_state_variables(sctx, eps_app_eng, sigma_kk)[:, 0])

        return beta_N_arr

    def _get_beta_T_arr(self, sctx, eps_app_eng, sigma_kk):

        # Returns a list of the tangential integrity factors for all
        # microplanes.

        beta_T_arr = sqrt(1 -
                          self._get_state_variables(sctx, eps_app_eng, sigma_kk)[:, 5])

        return beta_T_arr

    def _get_beta_tns(self, sctx, eps_app_eng, sigma_kk):

        # Returns the 4th order damage tensor 'beta4' using separate normal
        # and tangential integrity factors (cf. [Baz99], Eq.(63))

        delta = identity(3)
        beta_N_n = self._get_beta_N_arr(sctx, eps_app_eng, sigma_kk)
        beta_T_n = self._get_beta_T_arr(sctx, eps_app_eng, sigma_kk)

        beta_ijkl = einsum('n, n, ni, nj, nk, nl -> ijkl', self._MPW, beta_N_n, self._MPN, self._MPN, self._MPN, self._MPN) + \
            0.25 * (einsum('n, n, ni, nk, jl -> ijkl', self._MPW, beta_T_n, self._MPN, self._MPN, delta) +
                    einsum('n, n, ni, nl, jk -> ijkl', self._MPW, beta_T_n, self._MPN, self._MPN, delta) +
                    einsum('n, n, nj, nk, il -> ijkl', self._MPW, beta_T_n, self._MPN, self._MPN, delta) +
                    einsum('n, n, nj, nl, ik -> ijkl', self._MPW, beta_T_n, self._MPN, self._MPN, delta) -
                    4 * einsum('n, n, ni, nj, nk, nl -> ijkl', self._MPW, beta_T_n, self._MPN, self._MPN, self._MPN, self._MPN))

        return beta_ijkl

    #-------------------------------------------------------------
    # Returns a list of the integrity factors for all microplanes.
    #-------------------------------------------------------------

    def _get_phi_arr(self, sctx, eps_app_eng, sigma_kk):
        # Evaluate the microplane laws once and read both damage variables
        # from the same state array.
        state_arr = self._get_state_variables(sctx, eps_app_eng, sigma_kk)
        w_N = state_arr[:, 0]
        w_T = state_arr[:, 5]

        # Use the normal damage when the volumetric strain (first strain
        # invariant: trace = sum of the principal strains) is positive and
        # the tangential damage otherwise.  The trace gives that sum
        # directly, without an eigen-decomposition that may return complex
        # eigenvalues (which cannot be compared with 0.0).
        if np.trace(eps_app_eng) > 0.0:
            w = w_N
        else:
            w = w_T

        phi_arr = sqrt(1.0 - w)

        return phi_arr

    #----------------------------------------------
    # Returns the 2nd order damage tensor 'phi_mtx'
    #----------------------------------------------
    def _get_phi_mtx(self, sctx, eps_app_eng, sigma_kk):

        # scalar integrity factor for each microplane
        phi_arr = self._get_phi_arr(sctx, eps_app_eng, sigma_kk)

        # integration terms for each microplanes
        phi_ij = einsum('n,n,nij->ij', phi_arr, self._MPW, self._MPNN)

        return phi_ij

    #----------------------------------------------------------------------
    # Returns the 4th order damage tensor 'beta4' using sum-type symmetrization
    #(cf. [Jir99], Eq.(21))
    #----------------------------------------------------------------------
    def _get_beta_tns_sum_type(self, sctx, eps_app_eng, sigma_kk):

        delta = identity(3)

        phi_mtx = self._get_phi_mtx(sctx, eps_app_eng, sigma_kk)

        # use numpy functionality (einsum) to evaluate [Jir99], Eq.(21)
        beta_ijkl = 0.25 * (einsum('ik,jl->ijkl', phi_mtx, delta) +
                            einsum('il,jk->ijkl', phi_mtx, delta) +
                            einsum('jk,il->ijkl', phi_mtx, delta) +
                            einsum('jl,ik->ijkl', phi_mtx, delta))

        return beta_ijkl

    #----------------------------------------------------------------------
    # Returns the 4th order damage tensor 'beta4' using product-type symmetrization
    #(cf. [Baz97], Eq.(87))
    #----------------------------------------------------------------------
    def _get_beta_tns_product_type(self, sctx, eps_app_eng, sigma_kk):

        phi_mtx = self._get_phi_mtx(sctx, eps_app_eng, sigma_kk)

        # Spectral decomposition phi = V diag(lambda) V^T of the symmetric
        # integrity tensor (the columns of V are the eigenvectors).
        phi_eig_value, phi_eig_mtx = eigh(phi_mtx)

        # w_mtx = tensorial square root of the second order integrity tensor,
        # w = V diag(sqrt(lambda)) V^T.  The right-hand factor is the
        # transposed eigenvector matrix ('jl'), i.e. dot(dot(V, W), V.T).
        w_pdc_mtx = np.diag(sqrt(phi_eig_value.real))
        w_mtx = einsum('ik,kl,jl -> ij', phi_eig_mtx, w_pdc_mtx, phi_eig_mtx)

        beta_ijkl = 0.5 * \
            (einsum('ik,jl -> ijkl', w_mtx, w_mtx) +
             einsum('il,jk -> ijkl', w_mtx, w_mtx))

        return beta_ijkl

    #-----------------------------------------------------------
    # Integration of the (inelastic) strains for each microplane
    #-----------------------------------------------------------
    def _get_eps_p_mtx(self, sctx, eps_app_eng, sigma_kk):

        # plastic normal strains
        eps_N_P_n = self._get_eps_N_p_arr(sctx, eps_app_eng, sigma_kk)

        # sliding tangential strains
        eps_T_pi_ni = self._get_eps_T_pi_arr(sctx, eps_app_eng, sigma_kk)
        delta = identity(3)

        # 2-nd order plastic (inelastic) tensor:
        # sum_n w_n * (eps_N_p n_i n_j + 0.5 (n_i eps_T_j + n_j eps_T_i))
        eps_p_ij = einsum('n,n,ni,nj -> ij', self._MPW, eps_N_P_n, self._MPN, self._MPN) + \
            0.5 * (einsum('n,nr,ni,rj->ij', self._MPW, eps_T_pi_ni, self._MPN, delta) +
                   einsum('n,nr,nj,ri->ij', self._MPW, eps_T_pi_ni, self._MPN, delta))

        return eps_p_ij

    #-------------------------------------------------------------------------
    # Evaluation - get the corrector and predictor
    #-------------------------------------------------------------------------

    def get_corr_pred(self, sctx, eps_app_eng, sigma_kk):
        '''
        Corrector predictor computation.

        Returns ``(sig_eng, D4_mdm_ijmn)``: the stress tensor computed from
        the elastic part of `eps_app_eng`, and the damaged secant stiffness
        tensor beta : D4_e : beta^T (sum-type symmetrization of the damage
        tensor).
        '''

        #------------------------------------------------------------------
        # Damage tensor (4th order) using sum-type symmetrization
        # (`_get_beta_tns_product_type` and `_get_beta_tns` are the
        # alternatives):
        #------------------------------------------------------------------
        beta_ijkl = self._get_beta_tns_sum_type(
            sctx, eps_app_eng, sigma_kk)

        #------------------------------------------------------------------
        # Damaged stiffness tensor calculated based on the damage tensor beta4:
        #------------------------------------------------------------------
        D4_mdm_ijmn = einsum(
            'ijkl,klsr,mnsr->ijmn', beta_ijkl, self.D4_e, beta_ijkl)

        #----------------------------------------------------------------------
        # Return stresses (corrector) and damaged secant stiffness matrix (predictor)
        #----------------------------------------------------------------------
        # plastic strain tensor
        eps_p_ij = self._get_eps_p_mtx(sctx, eps_app_eng, sigma_kk)

        # elastic strain tensor
        eps_e_mtx = eps_app_eng - eps_p_ij

        # calculation of the stress tensor
        sig_eng = einsum('ijmn,mn -> ij', D4_mdm_ijmn, eps_e_mtx)

        return sig_eng, D4_mdm_ijmn
class MATS3DMicroplaneDamageJir(MATSXDMicroplaneDamageFatigueJir):
    '''
    Three-dimensional microplane damage model.

    Specializes `MATSXDMicroplaneDamageFatigueJir` with a fixed 28-point
    microplane integration rule (normals `_MPN`, weights `_MPW`) and the
    isotropic 3D elasticity tensor derived from `E` and `nu`.
    '''

    #-----------------------------------------------
    # Number of microplanes (fixed to 28 in 3D)
    #-----------------------------------------------
    n_mp = Constant(28)

    #-----------------------------------------------
    # Unit normals of the microplanes
    #-----------------------------------------------
    _MPN = Property(depends_on='n_mp')

    @cached_property
    def _get__MPN(self):
        # 28 unit normals on one hemisphere (first component > 0), in three
        # families obtained by coordinate permutations and sign changes:
        # rows 0-3, rows 4-15 and rows 16-27.
        normals = array([[.577350259, .577350259, .577350259],
                         [.577350259, .577350259, -.577350259],
                         [.577350259, -.577350259, .577350259],
                         [.577350259, -.577350259, -.577350259],
                         [.935113132, .250562787, .250562787],
                         [.935113132, .250562787, -.250562787],
                         [.935113132, -.250562787, .250562787],
                         [.935113132, -.250562787, -.250562787],
                         [.250562787, .935113132, .250562787],
                         [.250562787, .935113132, -.250562787],
                         [.250562787, -.935113132, .250562787],
                         [.250562787, -.935113132, -.250562787],
                         [.250562787, .250562787, .935113132],
                         [.250562787, .250562787, -.935113132],
                         [.250562787, -.250562787, .935113132],
                         [.250562787, -.250562787, -.935113132],
                         [.186156720, .694746614, .694746614],
                         [.186156720, .694746614, -.694746614],
                         [.186156720, -.694746614, .694746614],
                         [.186156720, -.694746614, -.694746614],
                         [.694746614, .186156720, .694746614],
                         [.694746614, .186156720, -.694746614],
                         [.694746614, -.186156720, .694746614],
                         [.694746614, -.186156720, -.694746614],
                         [.694746614, .694746614, .186156720],
                         [.694746614, .694746614, -.186156720],
                         [.694746614, -.694746614, .186156720],
                         [.694746614, -.694746614, -.186156720]])
        return normals

    #-------------------------------------
    # Integration weights of the microplanes
    #-------------------------------------
    _MPW = Property(depends_on='n_mp')

    @cached_property
    def _get__MPW(self):
        # One Gaussian weight per family of normals (4, 12 and 12 planes).
        # The tabulated values are for integration over the unit hemisphere
        # and sum to approximately 0.5 (cf. [BazLuz04]); they are scaled by
        # 6 (cf. [Baz05]), so the returned weights sum to approximately 3.
        w_family_1 = .0160714276
        w_family_2 = .0204744730
        w_family_3 = .0158350505
        weights = [w_family_1] * 4 + [w_family_2] * 12 + [w_family_3] * 12
        return array(weights) * 6.0

    #-------------------------------------------------------------------------
    # Cached elasticity tensors
    #-------------------------------------------------------------------------

    elasticity_tensors = Property(
        depends_on='E, nu, dimensionality, stress_state')

    @cached_property
    def _get_elasticity_tensors(self):
        '''
        Return the fourth order isotropic elasticity tensor for the 3D case,

            D_ijkl = la * d_ij d_kl + mu * (d_ik d_jl + d_il d_jk),

        where la and mu are the Lame constants computed from `E` and `nu`
        and d is the Kronecker delta.
        '''
        E, nu = self.E, self.nu

        # first Lame parameter
        lame_lambda = E * nu / ((1 + nu) * (1 - 2 * nu))
        # second Lame parameter (shear modulus)
        shear_modulus = E / (2 + 2 * nu)

        kronecker = identity(3)
        volumetric_part = einsum(',ij,kl->ijkl', lame_lambda, kronecker, kronecker)
        shear_part_1 = einsum(',ik,jl->ijkl', shear_modulus, kronecker, kronecker)
        shear_part_2 = einsum(',il,jk->ijkl', shear_modulus, kronecker, kronecker)
        return volumetric_part + shear_part_1 + shear_part_2

    #-------------------------------------------------------------------------
    # Dock-based view with its own id
    #-------------------------------------------------------------------------

    traits_view = View(Include('polar_fn_group'),
                       dock='tab',
                       id='ibvpy.mats.mats3D.mats_3D_cmdm.MATS3D_cmdm',
                       kind='modal',
                       resizable=True,
                       scrollable=True,
                       width=0.6, height=0.8,
                       buttons=['OK', 'Cancel']
                       )
Ejemplo n.º 12
0
class SkeletonTask(Task):
    """ A minimal task with an editor area as its central pane, three
    skeleton dock panes, and a 'New' action that opens a blank editor.
    """

    #### Task interface #######################################################

    id = 'omnivore.framework.text_edit_task'
    name = 'Skeleton'

    # The editor currently active in the editor area, or None when there is
    # no editor area yet.
    active_editor = Property(Instance(IEditor),
                             depends_on='editor_area.active_editor')

    # The central pane; assigned by `create_central_pane`.
    editor_area = Instance(IEditorAreaPane)

    menu_bar = SMenuBar(
        SMenu(TaskAction(name='New', method='new', accelerator='Ctrl+N'),
              id='File', name='&File'),
        SMenu(id='Edit', name='&Edit'),
        SMenu(TaskToggleGroup(), id='View', name='&View'),
    )

    tool_bars = [
        SToolBar(
            TaskAction(method='new',
                       tooltip='New file',
                       image=ImageResource('document_new')),
            image_size=(32, 32),
        ),
    ]

    ###########################################################################
    # 'Task' interface.
    ###########################################################################

    def _default_layout_default(self):
        # Pane 1 on top of a horizontal split of panes 2 and 3, all on the
        # left-hand side of the window.
        top_pane = PaneItem('text_edit.pane1')
        bottom_row = HSplitter(PaneItem('text_edit.pane2'),
                               PaneItem('text_edit.pane3'))
        return TaskLayout(left=VSplitter(top_pane, bottom_row))

    def create_central_pane(self):
        """ Create the editor area used as the central pane and remember it
        in `editor_area`.
        """
        area = EditorAreaPane()
        self.editor_area = area
        return self.editor_area

    def create_dock_panes(self):
        """ Create the three skeleton dock panes.
        """
        pane_classes = (SkeletonPane1, SkeletonPane2, SkeletonPane3)
        return [pane_class() for pane_class in pane_classes]

    ###########################################################################
    # 'SkeletonTask' interface.
    ###########################################################################

    def new(self):
        """ Open a new, empty editor in the editor area and activate it.
        """
        blank_editor = Editor()
        area = self.editor_area
        area.add_editor(blank_editor)
        area.activate_editor(blank_editor)
        # NOTE(review): invokes the Task `activated` hook directly; confirm
        # this is intended rather than relying on the framework to call it.
        self.activated()

    #### Trait property getter/setters ########################################

    def _get_active_editor(self):
        area = self.editor_area
        return None if area is None else area.active_editor
Ejemplo n.º 13
0
class MATSXDMicroplaneDamageFatigueWu(MATSEvalMicroplaneFatigue):
    '''
    Microplane Damage Model.

    The microplane integrity is driven by an equivalent microplane strain
    (`_get_e_equiv_arr`) and homogenized into volumetric and deviatoric
    damage tensors (`_get_M_vol_tns`, `_get_M_dev_tns`).
    '''

    # specification of the model dimension (2D, 3D)
    n_dim = Int

    # specification of number of engineering strain and stress components
    n_eng = Int

    #-------------------------------------------------------------------------
    # Configuration parameters
    #-------------------------------------------------------------------------
    # NOTE(review): `model_version`, `symmetrization` and `regularization`
    # are not read by the methods in this part of the class; presumably they
    # select formulation variants elsewhere - confirm before relying on them.

    # Formulation variant (compliance- or stiffness-based) - presumably.
    model_version = Enum("compliance", "stiffness")

    # Symmetrization used to build the 4th order damage tensor - presumably.
    symmetrization = Enum("product-type", "sum-type")

    # Flag for the element length projection in the direction of the
    # principal strains (see `desc`).
    regularization = Bool(False,
                          desc='Flag to use the element length projection'
                          ' in the direction of principle strains',
                          enter_set=True,
                          auto_set=False)

    # Debug switch for elastic behaviour (see `desc`).
    elastic_debug = Bool(
        False,
        desc='Switch to elastic behavior - used for debugging',
        auto_set=False)

    # Double kinematic constraint for the energy evaluation; per `desc` it
    # only affects the response tracers.
    double_constraint = Bool(
        False,
        desc=
        'Use double constraint to evaluate microplane elastic and fracture energy (Option effects only the response tracers)',
        auto_set=False)

    #-------------------------------------------------------------------------
    # View specification
    #-------------------------------------------------------------------------

    # Group exposing the configuration parameters above in a UI.
    config_param_vgroup = Group(
        Item('model_version', style='custom'),
        #     Item('stress_state', style='custom'),
        Item('symmetrization', style='custom'),
        Item('elastic_debug@'),
        Item('double_constraint@'),
        Spring(resizable=True),
        label='Configuration parameters',
        show_border=True,
        dock='tab',
        id='ibvpy.mats.matsXD.MATSXD_cmdm.config',
    )

    # Default modal view.  It includes 'polar_fn_group' (not defined in this
    # part of the class - presumably inherited) and not
    # `config_param_vgroup`.
    traits_view = View(Include('polar_fn_group'),
                       dock='tab',
                       id='ibvpy.mats.matsXD.MATSXD_cmdm',
                       kind='modal',
                       resizable=True,
                       scrollable=True,
                       width=0.6,
                       height=0.8,
                       buttons=['OK', 'Cancel'])

    #-------------------------------------------------------------------------
    # Setup for computation within a supplied spatial context
    #-------------------------------------------------------------------------

    # Elasticity tensor used by the model: the first entry of
    # `elasticity_tensors`.
    D4_e = Property

    def _get_D4_e(self):
        all_tensors = self.elasticity_tensors
        return all_tensors[0]

    #-------------------------------------------------------------------------
    # MICROPLANE-DISCRETIZATION RELATED METHOD
    #-------------------------------------------------------------------------

    # Dyadic product n_i n_j of the normal of each microplane (cached).
    _MPNN = Property(depends_on='n_mp')

    @cached_property
    def _get__MPNN(self):
        normals = self._MPN
        # Outer product of every normal with itself; index n = microplane.
        return einsum('ni,nj->nij', normals, normals)

    # Third order tangential tensor (operator) for each microplane (cached).
    _MPTT = Property(depends_on='n_mp')

    @cached_property
    def _get__MPTT(self):
        # T_nijr = 0.5 * (n_i delta_jr + n_j delta_ir) - n_i n_j n_r
        #
        # Both Kronecker terms must use the output subscripts 'nijr':
        # labelling the second one 'njir' would make it an exact copy of the
        # first term instead of its (i, j)-transpose.
        delta = identity(3)
        MPTT_nijr = 0.5 * (
            einsum('ni,jr -> nijr', self._MPN, delta) +
            einsum('nj,ir -> nijr', self._MPN, delta) -
            2 * einsum('ni,nj,nr -> nijr', self._MPN, self._MPN, self._MPN))
        return MPTT_nijr

    def _get_e_vct_arr(self, eps_eng):
        # Strain vector on every microplane: projection of the apparent
        # strain tensor onto the normals, e_ni = n_nj * eps_ji.
        return einsum('nj,ji->ni', self._MPN, eps_eng)

    def _get_e_N_arr(self, e_vct_arr):
        # Normal strain on every microplane: component of the microplane
        # strain vector along its normal, e_N_n = e_ni * n_ni.
        normals = self._MPN
        return einsum('ni,ni->n', e_vct_arr, normals)

    def _get_e_T_vct_arr(self, e_vct_arr):
        # Tangential strain vector on every microplane: the microplane
        # strain vector minus its normal part e_N_n * n_ni.
        normal_strain_n = self._get_e_N_arr(e_vct_arr)
        normal_part_ni = einsum('n,ni->ni', normal_strain_n, self._MPN)
        tangential_ni = e_vct_arr - normal_part_ni
        return tangential_ni

    # Alternative methods for the kinematic constraint

    def _get_e_N_arr_2(self, eps_eng):
        # Normal microplane strains taken straight from the strain tensor:
        # e_N_n = (n_i n_j)_n * eps_ij
        dyads_nij = self._MPNN
        return einsum('nij,ij->n', dyads_nij, eps_eng)

    def _get_e_t_vct_arr_2(self, eps_eng):
        # Tangential microplane strain vectors taken straight from the
        # strain tensor via the third order tangential operator:
        # e_T_nr = T_nijr * eps_ij
        tangential_op_nijr = self._get__MPTT()
        return einsum('nijr,ij->nr', tangential_op_nijr, eps_eng)

    def _get_e_vct_arr_2(self, eps_eng):
        # Microplane strain vectors assembled from their normal and
        # tangential parts (alternative to `_get_e_vct_arr`):
        #   e_ni = e_N_n * n_ni + e_T_ni
        # The parts are computed by `_get_e_N_arr_2` / `_get_e_t_vct_arr_2`
        # (there are no `_e_N_arr_2` / `_e_t_vct_arr_2` attributes), and the
        # per-plane normal strains are broadcast onto the normals explicitly.
        e_N_n = self._get_e_N_arr_2(eps_eng)
        e_T_ni = self._get_e_t_vct_arr_2(eps_eng)
        return einsum('n,ni->ni', e_N_n, self._MPN) + e_T_ni

    def _get_I_vol_4(self):
        # Fourth order volumetric identity: (1/3) delta_ij delta_kl
        kronecker = identity(3)
        return (1.0 / 3.0) * einsum('ij,kl -> ijkl', kronecker, kronecker)

    def _get_I_dev_4(self):
        # Fourth order deviatoric identity:
        #   0.5 * (d_ik d_jl + d_il d_jk) - (1/3) * d_ij d_kl
        kronecker = identity(3)
        symmetric_identity = 0.5 * (einsum('ik,jl -> ijkl', kronecker, kronecker) +
                                    einsum('il,jk -> ijkl', kronecker, kronecker))
        volumetric_part = (1 / 3.0) * einsum('ij,kl -> ijkl', kronecker, kronecker)
        return symmetric_identity - volumetric_part

    def _get_P_vol(self):
        # Second order volumetric projection: (1/3) delta_ij
        return (1 / 3.0) * identity(3)

    def _get_P_dev(self):
        # Per-microplane tensor 0.5 * n_j * delta_kl (the contraction over i
        # with delta_ij just selects n_j).
        # NOTE(review): this has no n_i n_j n_k / 3 correction and depends on
        # the normal only through n_j - verify against the intended
        # deviatoric projection before using it.
        kronecker = identity(3)
        return 0.5 * einsum('ni,ij,kl -> njkl', self._MPN, kronecker, kronecker)

    def _get_PP_vol_4(self):
        # Outer product of P_vol with itself: (1/9) delta_ij delta_kl
        kronecker = identity(3)
        return (1 / 9.) * einsum('ij,kl -> ijkl', kronecker, kronecker)

    def _get_PP_dev_4(self):
        # Per-microplane fourth order tensor built from the normals
        # ("inner product of P_dev").  With d = Kronecker delta it reads
        #   0.5 * (0.5 * (n_i n_k d_jl + n_i n_l d_jk)
        #          + 0.5 * (d_ik n_j n_l + d_il n_j n_k))
        #   - 1/3 * (n_i n_j d_kl + d_ij n_k n_l)
        #   + 1/9 * d_ij d_kl
        # The first two groups carry the microplane index n; the last term
        # has no n index and is broadcast over all microplanes.
        # NOTE(review): confirm this matches the intended P_dev definition.
        delta = identity(3)
        PP_dev_nijkl = 0.5 * (0.5 * (einsum('ni,nk,jl -> nijkl', self._MPN, self._MPN, delta) +
                                     einsum('ni,nl,jk -> nijkl', self._MPN, self._MPN, delta)) +
                              0.5 * (einsum('ik,nj,nl -> nijkl',  delta, self._MPN, self._MPN) +
                                     einsum('il,nj,nk -> nijkl',  delta, self._MPN, self._MPN))) -\
            (1 / 3.) * (einsum('ni,nj,kl -> nijkl', self._MPN, self._MPN, delta) +
                        einsum('ij,nk,nl -> nijkl', delta, self._MPN, self._MPN)) +\
            (1 / 9.) * einsum('ij,kl -> ijkl', delta, delta)

        return PP_dev_nijkl

    def _get_e_equiv_arr(self, e_vct_arr):
        '''
        Return the equivalent strain of every microplane,

            e_eq = sqrt(<e_N>^2 + c_T * |e_T|^2),

        where <e_N> is the positive part of the normal strain and e_T the
        tangential strain vector, both derived from the microplane strain
        vectors `e_vct_arr`.
        '''
        normal_n = self._get_e_N_arr(e_vct_arr)
        # Positive part (Macaulay bracket) of the normal strain.
        tensile_n = (fabs(normal_n) + normal_n) / 2
        normal_vct_ni = einsum('n,ni -> ni', normal_n, self._MPN)
        ratio_T = self.c_T
        tangential_ni = e_vct_arr - normal_vct_ni
        tangential_sq_n = einsum('ni,ni -> n', tangential_ni, tangential_ni)
        return sqrt(tensile_n * tensile_n + ratio_T * tangential_sq_n)

    def _get_e_max(self, e_equiv_arr, e_max_arr):
        '''
        Return, per microplane, the current equivalent strain where it is at
        least the stored maximum, and the stored maximum elsewhere.

        The result is a new array: `e_max_arr` (the values from the state
        array) is never written to, so no side effects reach the state.
        '''
        reached = e_equiv_arr >= e_max_arr
        updated_max = copy(e_max_arr)
        updated_max[reached] = e_equiv_arr[reached]
        return updated_max

    def _get_state_variables(self, sctx, eps_app_eng):
        '''
        Return the current equivalent strain of every microplane for the
        apparent strain tensor `eps_app_eng`.

        NOTE: the comparison with the history maximum stored in the state
        array (via `_get_e_max`) is currently disabled, so `sctx` is unused
        and no loading history is taken into account.
        '''
        strain_vectors = self._get_e_vct_arr(eps_app_eng)
        return self._get_e_equiv_arr(strain_vectors)

    def _get_phi_arr(self, sctx, eps_app_eng):
        # Integrity factor of every microplane: first entry of the result of
        # `_get_phi` evaluated at the microplane equivalent strains.
        equiv_strains = self._get_state_variables(sctx, eps_app_eng)
        phi_result = self._get_phi(equiv_strains, sctx)
        return phi_result[0]

    def _get_phi_mtx(self, sctx, eps_app_eng):
        # Second order integrity tensor: weighted sum over the microplanes
        # of phi_n * n_i n_j.
        phi_n = self._get_phi_arr(sctx, eps_app_eng)
        return einsum('n,n,nij->ij', phi_n, self._MPW, self._MPNN)

    def _get_d_scalar(self, sctx, eps_app_eng):

        # scalar integrity factor for each microplane
        phi_arr = self._get_phi_arr(sctx, eps_app_eng)

        d_arr = 1.0 - phi_arr

        d = (1.0 / 3.0) * einsum('n,n->', d_arr, self._MPW)

        # print d

        return d

    def _get_M_vol_tns(self, sctx, eps_app_eng):

        d = self._get_d_scalar(sctx, eps_app_eng)
        delta = identity(3)

        # print d
        #I_4th_ijkl = einsum('ik,jl -> ijkl', delta, delta)

        I_4th_ijkl = 0.5 * (einsum('ik,jl -> ijkl', delta, delta) +
                            einsum('il,jk -> ijkl', delta, delta))

        return (1 - d) * I_4th_ijkl

    def _get_M_dev_tns(self, phi_mtx):
        '''
        Returns the 4th order deviatoric damage tensor
        '''
        delta = identity(3)

        # use numpy functionality (einsum) to evaluate [Jir99], Eq.(21)
        # M_dev_ijkl = 0.25 * (einsum('ik,jl->ijkl', phi_mtx, delta) +
        #                    einsum('il,jk->ijkl', phi_mtx, delta) +
        #                     einsum('jk,il->ijkl', phi_mtx, delta) +
        #                     einsum('jl,ik->ijkl', phi_mtx, delta))
        I_4th_ijkl = 0.5 * (einsum('ik,jl -> ijkl', delta, delta) +
                            einsum('il,jk -> ijkl', delta, delta))

        M_dev_ijkl = self.zeta_G * (0.5 * (einsum('ik,jl->ijkl', delta, phi_mtx) +
                                           einsum('il,jk->ijkl', delta, phi_mtx)) +
                                    0.5 * (einsum('ik,jl->ijkl', phi_mtx, delta) +
                                           einsum('il,jk->ijkl', phi_mtx, delta))) -\
            (2.0 * self.zeta_G - 1.0) * I_4th_ijkl * \
            (1.0 / 3.0) * trace(phi_mtx)

        # print 'M_dev_ijkl', M_dev_ijkl

        return M_dev_ijkl

    #-------------------------------------------------------------------------
    # Secant stiffness (irreducible decomposition based on ODFs)
    #-------------------------------------------------------------------------

    def _get_S_1_tns(self, sctx, eps_app_eng):
        #----------------------------------------------------------------------
        # Returns the fourth order secant stiffness tensor (eq.1)
        #----------------------------------------------------------------------
        K0 = self.E / (1. - 2. * self.nu)
        G0 = self.E / (1. + self.nu)

        e_max_arr = self._get_state_variables(sctx, eps_app_eng)
        d_n = self._get_phi(e_max_arr, sctx)[1]

        PP_vol_4 = self._get_PP_vol_4()
        PP_dev_4 = self._get_PP_dev_4()
        delta = identity(3)
        I_4th_ijkl = einsum('ik,jl -> ijkl', delta, delta)
        I_dev_4 = self._get_I_dev_4()

        S_1_ijkl = K0 * einsum('n,n,ijkl->ijkl', d_n, self._MPW, PP_vol_4) + \
            G0 * 2. * self.zeta_G * einsum('n,n,nijkl->ijkl', d_n, self._MPW, PP_dev_4) - (1. / 3.) * (
                2. * self.zeta_G - 1.) * G0 * einsum('n,n,ijkl->ijkl', d_n, self._MPW, I_dev_4)

        return S_1_ijkl

    def _get_S_2_tns(self, sctx, eps_app_eng):
        #----------------------------------------------------------------------
        # Returns the fourth order secant stiffness tensor (eq.2)
        #----------------------------------------------------------------------
        K0 = self.E / (1. - 2. * self.nu)
        G0 = self.E / (1. + self.nu)

        I_vol_ijkl = self._get_I_vol_4()
        I_dev_ijkl = self._get_I_dev_4()
        phi_mtx = self._get_phi_mtx(sctx, eps_app_eng)
        M_vol_ijkl = self._get_M_vol_tns(sctx, eps_app_eng)
        M_dev_ijkl = self._get_M_dev_tns(phi_mtx)

        S_2_ijkl = K0 * einsum('ijmn,mnrs,rskl -> ijkl', I_vol_ijkl, M_vol_ijkl, I_vol_ijkl) \
            + G0 * einsum('ijmn,mnrs,rskl -> ijkl', I_dev_ijkl, M_dev_ijkl, I_dev_ijkl)\

        # print 'S_vol = ', K0 * einsum('ijmn,mnrs,rskl -> ijkl', I_vol_ijkl, M_vol_ijkl, I_vol_ijkl)
        # print 'S_dev = ', G0 * einsum('ijmn,mnrs,rskl -> ijkl', I_dev_ijkl,
        # M_dev_ijkl, I_dev_ijkl)

        return S_2_ijkl

    def _get_S_22_tns(self, sctx, eps_app_eng):
        '''
        Return the fourth-order secant stiffness tensor (double orthotropic),
        written out without simplification.

        With D_ij = delta_ij - phi_ij, d = tr(D) / 3 and
        D_bar_ij = zeta_G * (D_ij - d * delta_ij):

            S = (1 - d) * K0 * I_vol + (1 - d) * G0 * I_dev
              + 2/3 * G0 * (delta_ij D_bar_kl + D_bar_ij delta_kl)
              - G0 * [ (delta_ik D_bar_jl + D_bar_il delta_jk) / 2
                     + (D_bar_ik delta_jl + delta_il D_bar_jk) / 2 ]

        with K0 = E / (1 - 2 nu) and G0 = E / (1 + nu).  Only S_22_ijkl is
        returned.
        '''

        K0 = self.E / (1. - 2. * self.nu)
        G0 = self.E / (1. + self.nu)

        I_vol_ijkl = self._get_I_vol_4()
        I_dev_ijkl = self._get_I_dev_4()
        delta = identity(3)
        phi_mtx = self._get_phi_mtx(sctx, eps_app_eng)
        # second-order damage tensor
        D_ij = delta - phi_mtx
        # scalar damage: one third of the trace of D_ij
        d = (1. / 3.) * trace(D_ij)
        # deviatoric part of D_ij, scaled by zeta_G
        D_bar_ij = self.zeta_G * (D_ij - d * delta)
        # damaged stiffness without simplification
        S_22_ijkl = (1. - d) * K0 * I_vol_ijkl + (1. - d) * G0 * I_dev_ijkl + (2. / 3.) * (G0) * \
            (einsum('ij,kl -> ijkl', delta, D_bar_ij) +
             einsum('ij,kl -> ijkl', D_bar_ij, delta)) - G0 * \
            (0.5 * (einsum('ik,jl -> ijkl', delta, D_bar_ij) + einsum('il,jk -> ijkl', D_bar_ij, delta)) +
             0.5 * (einsum('ik,jl -> ijkl', D_bar_ij, delta) + einsum('il,jk -> ijkl', delta, D_bar_ij)))

        # NOTE(review): S_2_ijkl below (the eq. 2 stiffness, same formula as
        # _get_S_2_tns) is computed but never returned or used.  It only adds
        # cost: _get_M_vol_tns re-evaluates the microplane integrity via
        # _get_d_scalar -> _get_phi_arr.  Confirm that _get_phi has no side
        # effects, then drop it.
        M_dev_ijkl = self._get_M_dev_tns(phi_mtx)
        M_vol_ijkl = self._get_M_vol_tns(sctx, eps_app_eng)
        S_2_ijkl = K0 * einsum('ijmn,mnrs,rskl -> ijkl', I_vol_ijkl, M_vol_ijkl, I_vol_ijkl) \
            + G0 * \
            einsum(
                'ijmn,mnrs,rskl -> ijkl', I_dev_ijkl, M_dev_ijkl, I_dev_ijkl)

        return S_22_ijkl

    def _get_S_3_tns(self, sctx, eps_app_eng):
        #----------------------------------------------------------------------
        # Returns the fourth order secant stiffness tensor (eq.3)
        #----------------------------------------------------------------------
        K0 = self.E / (1. - 2. * self.nu)
        G0 = self.E / (1. + self.nu)

        I_vol_ijkl = self._get_I_vol_4()
        I_dev_ijkl = self._get_I_dev_4()

        S_0_ijkl = K0 * I_vol_ijkl + G0 * I_dev_ijkl

        e_max_arr = self._get_state_variables(sctx, eps_app_eng)

        d_n = 1.0 - self._get_phi(e_max_arr, sctx)[0]

        PP_vol_4 = self._get_PP_vol_4()
        PP_dev_4 = self._get_PP_dev_4()

        delta = identity(3)
        I_4th_ijkl = einsum('ik,jl -> ijkl', delta, delta)
        # print 'I_4th_ijkl', I_4th_ijkl

        D_ijkl = einsum('n,n,ijkl->ijkl', d_n, self._MPW, PP_vol_4) + \
            einsum('n,n,nijkl->ijkl', d_n, self._MPW, PP_dev_4)
        # print 'D_ijkl', D_ijkl
        phi_ijkl = (I_4th_ijkl - D_ijkl)
        # print 'phi_ijkl', phi_ijkl
        S_ijkl = einsum('ijmn,mnkl', phi_ijkl, S_0_ijkl)

        PP_vol_int = einsum('n,abcd->abcd', self._MPW, PP_vol_4)
        # print 'PP_vol_int', PP_vol_int
        # print 'I_vol_ijkl', I_vol_ijkl

        return S_ijkl

    def _get_S_4_tns(self, sctx, eps_app_eng):
        #----------------------------------------------------------------------
        # Returns the fourth order secant stiffness tensor (double orthotropic)
        #----------------------------------------------------------------------

        K0 = self.E / (1.0 - 2.0 * self.nu)
        G0 = self.E / (1.0 + self.nu)

        I_vol_ijkl = self._get_I_vol_4()
        I_dev_ijkl = self._get_I_dev_4()
        delta = identity(3)
        phi_mtx = self._get_phi_mtx(sctx, eps_app_eng)
        # print 'phi_mtx', phi_mtx
        D_ij = delta - phi_mtx
        d = (1.0 / 3.0) * trace(D_ij)
        print('d_1', d)
        d = self._get_d_scalar(sctx, eps_app_eng)
        print('d_2', d)

        D_bar_ij = self.zeta_G * (D_ij - d * delta)
        # print 'D_bar_ij', D_bar_ij
        # print 'D_ij', D_ij
        S_4_ijkl = (1. - d) * K0 * I_vol_ijkl + (1. - d) * G0 * I_dev_ijkl + (2. / 3.) * (G0 - K0) * \
            (einsum('ij,kl -> ijkl', delta, D_bar_ij) +
             einsum('ij,kl -> ijkl', D_bar_ij, delta)) + 0.5 * (- K0 + 2. * G0) * \
            (0.5 * (einsum('ik,jl -> ijkl', delta, D_bar_ij) + einsum('il,jk -> ijkl', D_bar_ij, delta)) +
             0.5 * (einsum('ik,jl -> ijkl', D_bar_ij, delta) + einsum('il,jk-> ijkl', delta, D_bar_ij)))

        # print einsum('ik,jl -> ijkl', delta, D_bar_ij)
        # print einsum('ik,jl -> ijkl', delta, D_bar_ij)

        #print ((1. - d) * delta - phi_mtx)
        # print delta

        return S_4_ijkl

    def _get_S_44_tns(self, sctx, eps_app_eng):
        #----------------------------------------------------------------------
        # Returns the fourth order secant stiffness tensor (restrctive orthotropic)
        #----------------------------------------------------------------------

        K0 = self.E / (1. - 2. * self.nu)
        G0 = self.E / (1. + self.nu)

        I_vol_ijkl = self._get_I_vol_4()
        I_dev_ijkl = self._get_I_dev_4()
        delta = identity(3)
        phi_mtx = self._get_phi_mtx(sctx, eps_app_eng)

        D_ij = delta - phi_mtx

        # damaged stiffness without simplification
        S_44_ijkl = (1. / 3.) * (K0 - G0) * 0.5 * ((einsum('ij,kl -> ijkl', delta, phi_mtx) +
                                                    einsum('ij,kl -> ijkl', phi_mtx, delta))) + \
            G0 * 0.5 * ((0.5 * (einsum('ik,jl -> ijkl', delta, phi_mtx) + einsum('il,jk -> ijkl', phi_mtx, delta)) +
                         0.5 * (einsum('ik,jl -> ijkl', phi_mtx, delta) + einsum('il,jk  -> ijkl', delta, phi_mtx))))

        return S_44_ijkl

    #-------------------------------------------------------------------------
    # Evaluation - get the corrector and predictor
    #-------------------------------------------------------------------------

    def get_corr_pred(self, sctx, eps_app_eng):
        '''
        Corrector predictor computation.

        @param eps_app_eng input variable - engineering strain

        Behaviour by mode:
        - ``elastic_debug`` on: a linear elastic response is returned as the
          tuple ``(sig_eng, D2_e)``.
        - otherwise: the stress tensor sig_ij = S_ijkl eps_kl is returned,
          using the secant stiffness from ``_get_S_4_tns``.

        NOTE(review): the two branches return different things (a tuple vs.
        a single array) and treat ``eps_app_eng`` differently (vector vs.
        matrix); confirm which form callers expect.
        '''
        if self.elastic_debug:
            # Debugging only: plain linear elasticity.  D2_e is copied so
            # that inserting essential boundary conditions later does not
            # modify self.D2_e.
            D2_e = copy(self.D2_e)
            return tensordot(D2_e, eps_app_eng, [[1], [0]]), D2_e

        # Crack-band regularisation (self.regularization) is not applied in
        # this variant.
        S_ijkl = self._get_S_4_tns(sctx, eps_app_eng)[:]
        return einsum('ijkl,kl -> ij', S_ijkl, eps_app_eng)
Ejemplo n.º 14
0
class MATS3DMicroplaneDamageWu(MATSXDMicroplaneDamageFatigueWu):
    '''
    Three-dimensional microplane damage model using a fixed 28-point
    microplane integration rule.
    '''

    # number of spatial dimensions
    n_dim = Constant(3)

    # number of components of the engineering tensor representation
    n_eng = Constant(6)

    #-------------------------------------------------------------------------
    # Microplane integration data (fixed 28-point rule in 3D)
    #-------------------------------------------------------------------------
    n_mp = Constant(28)

    # unit normal vectors of the microplanes, one row per microplane
    _MPN = Property(depends_on='n_mp')

    @cached_property
    def _get__MPN(self):
        # The 28 normals form seven families of four.  Within each family the
        # first component is positive, and the signs of the second and third
        # components run through (+, +), (+, -), (-, +), (-, -).
        families = [
            (.577350259, .577350259, .577350259),
            (.935113132, .250562787, .250562787),
            (.250562787, .935113132, .250562787),
            (.250562787, .250562787, .935113132),
            (.186156720, .694746614, .694746614),
            (.694746614, .186156720, .694746614),
            (.694746614, .694746614, .186156720),
        ]
        return array([[x, sign_y * y, sign_z * z]
                      for x, y, z in families
                      for sign_y in (1.0, -1.0)
                      for sign_z in (1.0, -1.0)])

    # integration weights of the microplanes, in the same order as _MPN
    _MPW = Property(depends_on='n_mp')

    @cached_property
    def _get__MPW(self):
        # Tabulated weights for Gaussian integration over the unit hemisphere.
        # They sum to 0.5 (cf. [BazLuz04]) and must be multiplied by 6
        # (cf. [Baz05]), so the returned weights sum to approximately 3.0.
        return array([.0160714276] * 4 +
                     [.0204744730] * 12 +
                     [.0158350505] * 12) * 6.0

    #-------------------------------------------------------------------------
    # Dock-based view with its own id
    #-------------------------------------------------------------------------

    traits_view = View(Include('polar_fn_group'),
                       dock='tab',
                       id='ibvpy.mats.mats3D.mats_3D_cmdm.MATS3D_cmdm',
                       kind='modal',
                       resizable=True,
                       scrollable=True,
                       width=0.6,
                       height=0.8,
                       buttons=['OK', 'Cancel'])
Ejemplo n.º 15
0
class GaussianMixture1DView(cytoflow.views.HistogramView):
    """
    Diagnostic view for a 1D Gaussian mixture.  It plots a histogram of the
    op's channel and overlays the density of each estimated mixture
    component.

    Attributes
    ----------
    op : Instance(GaussianMixture1DOp)
        The op whose parameters we're viewing.
    """

    id = 'edu.mit.synbio.cytoflow.view.gaussianmixture1dview'
    friendly_id = "1D Gaussian Mixture Diagnostic Plot"

    # TODO - why can't I use GaussianMixture1DOp here?
    op = Instance(IOperation)
    channel = DelegatesTo('op')
    scale = DelegatesTo('op')

    # The op's `by` conditions that are not used as x/y facets.  Each
    # combination of their values is drawn as a separate plot.
    _by = Property(List)

    def _get__by(self):
        # Keep the order of op.by.  A set difference would make the order,
        # and therefore the layout of the group keys, arbitrary.
        facets = [f for f in [self.xfacet, self.yfacet] if f]
        by = []
        for b in self.op.by:
            if b not in facets and b not in by:
                by.append(b)
        return by

    def enum_plots(self, experiment):
        """
        Returns an iterator over the possible plots that this View can
        produce.  The values returned can be passed to "plot".

        Raises CytoflowViewError if a facet is missing from the experiment
        or from op.by.  Raises CytoflowOpError if an op.by condition is
        missing from the experiment's data.
        """

        if self.xfacet and self.xfacet not in experiment.conditions:
            raise util.CytoflowViewError(
                "X facet {} not in the experiment".format(self.xfacet))

        if self.xfacet and self.xfacet not in self.op.by:
            raise util.CytoflowViewError(
                "X facet {} must be in GaussianMixture1DOp.by, which is {}".
                format(self.xfacet, self.op.by))

        if self.yfacet and self.yfacet not in experiment.conditions:
            raise util.CytoflowViewError(
                "Y facet {0} not in the experiment".format(self.yfacet))

        if self.yfacet and self.yfacet not in self.op.by:
            raise util.CytoflowViewError(
                "Y facet {} must be in GaussianMixture1DOp.by, which is {}".
                format(self.yfacet, self.op.by))

        for b in self.op.by:
            if b not in experiment.data:
                raise util.CytoflowOpError("Aggregation metadata {0} not found"
                                           " in the experiment".format(b))

        class plot_enum(object):
            # Yields the group keys of experiment.data grouped by view._by,
            # or a single None when there is nothing to group by.
            def __init__(self, view, experiment):
                self._iter = None
                self._returned = False

                if view._by:
                    self._iter = iter(experiment.data.groupby(view._by))

            def __iter__(self):
                return self

            def next(self):
                if self._iter is not None:
                    return next(self._iter)[0]
                if self._returned:
                    raise StopIteration
                self._returned = True
                return None

            # Python 3 spelling of the iterator protocol
            __next__ = next

        return plot_enum(self, experiment)

    def plot(self, experiment, plot_name=None, **kwargs):
        """
        Plot the plots.

        Draws the histogram of ``op.channel``.  On each facet it then draws
        the density of every estimated mixture component, scaled to the area
        of the histogram.

        Parameters
        ----------
        experiment : Experiment
            The data to plot.  It is cloned before the op is applied.
        plot_name : optional
            One of the values yielded by ``enum_plots``.  If it is None and
            ``_by`` is non-empty, every plot from ``enum_plots`` is drawn.
        min_quantile, max_quantile : float (keyword, default 0.001 / 0.999)
            Quantiles of the channel used as common x limits when ``xlim`` is
            not given.
        xlim : (float, float) (keyword)
            Explicit x limits.

        Remaining keyword arguments are passed to HistogramView.plot.
        """
        if not experiment:
            raise util.CytoflowViewError("No experiment specified")

        if not self.op.channel:
            raise util.CytoflowViewError("No channel specified")

        experiment = experiment.clone()

        # try to apply the current operation
        try:
            experiment = self.op.apply(experiment)
        except util.CytoflowOpError:
            # could have failed because no GMMs have been estimated, or because
            # op has already been applied
            pass

        # if apply() succeeded (or wasn't needed), set up the hue facet
        if self.op.name and self.op.name in experiment.conditions:
            if self.huefacet and self.huefacet != self.op.name:
                warn(
                    "Resetting huefacet to the model component (was {}, now {})."
                    .format(self.huefacet, self.op.name))
            self.huefacet = self.op.name

        if self.subset:
            try:
                experiment = experiment.query(self.subset)
                experiment.data.reset_index(drop=True, inplace=True)
            except Exception:
                raise util.CytoflowViewError(
                    "Subset string '{0}' isn't valid".format(self.subset))

            if len(experiment) == 0:
                raise util.CytoflowViewError(
                    "Subset string '{0}' returned no events".format(
                        self.subset))

        # figure out common x limits for multiple plots
        # adjust the limits to clip extreme values
        min_quantile = kwargs.pop("min_quantile", 0.001)
        max_quantile = kwargs.pop("max_quantile", 0.999)

        xlim = kwargs.pop("xlim", None)
        if xlim is None:
            xlim = (experiment.data[self.op.channel].quantile(min_quantile),
                    experiment.data[self.op.channel].quantile(max_quantile))

        # see if we're making subplots
        if self._by and plot_name is None:
            for plot_key in self.enum_plots(experiment):
                self.plot(experiment, plot_key, xlim=xlim, **kwargs)
                plt.title("{0} = {1}".format(self.op.by, plot_key))
            return

        if plot_name is not None:
            if not self._by:
                raise util.CytoflowViewError(
                    "Plot {} not from plot_enum".format(plot_name))

            groupby = experiment.data.groupby(self._by)

            if plot_name not in set(groupby.groups.keys()):
                raise util.CytoflowViewError(
                    "Plot {} not from plot_enum".format(plot_name))

            experiment.data = groupby.get_group(plot_name)
            experiment.data.reset_index(drop=True, inplace=True)

        # get the parameterized scale object back from the op
        scale = self.op._scale

        # plot the histogram, whether or not we're plotting distributions on top
        g = super(GaussianMixture1DView, self).plot(experiment,
                                                    scale=scale,
                                                    xlim=xlim,
                                                    **kwargs)

        # Overlay the mixture components on each facet.  None (rather than
        # False) marks a missing row/column facet, so that falsy facet values
        # such as 0 or False still take part in the GMM lookup.
        row_names = g.row_names if g.row_names else [None]
        col_names = g.col_names if g.col_names else [None]

        for (i, row), (j, col) in product(enumerate(row_names),
                                          enumerate(col_names)):

            facets = tuple(x for x in (row, col) if x is not None)
            if plot_name is None:
                gmm_name = facets
            elif isinstance(plot_name, (tuple, list)):
                gmm_name = tuple(plot_name) + facets
            else:
                # a scalar group key; a string must not be split into chars
                gmm_name = (plot_name,) + facets

            if not gmm_name:
                # no grouping at all: the single GMM is stored under True
                gmm_key = True
            elif len(gmm_name) == 1:
                gmm_key = gmm_name[0]
            else:
                gmm_key = gmm_name

            if gmm_key not in self.op._gmms:
                if gmm_name:
                    # there weren't any events in this subset to estimate a GMM
                    # from; skip this facet but keep drawing the others
                    warn("No estimated GMM for plot {}".format(gmm_key),
                         util.CytoflowViewWarning)
                continue

            gmm = self.op._gmms[gmm_key]
            ax = g.facet_axis(i, j)

            # Evaluate all components on one grid spanning this facet's x
            # range.  Use ax, not plt.gca(), which need not be this facet.
            plt_min, plt_max = ax.get_xlim()
            x = scale.inverse(
                np.linspace(scale(plt_min), scale(plt_max), 500))
            palette = sns.color_palette()

            for k in range(len(gmm.means_)):
                # Scale each density to the same area under the curve as the
                # histogram, taken from the k-th Polygon patch drawn on this
                # axis.
                xy = ax.patches[k].get_xy()
                pdf_scale = poly_area([scale(p[0]) for p in xy],
                                      [p[1] for p in xy])

                mean = gmm.means_[k][0]
                stdev = np.sqrt(gmm.covariances_[k][0])
                y = stats.norm.pdf(scale(x), mean, stdev) * pdf_scale
                ax.plot(x, y, color=palette[k % len(palette)])

        return g
Ejemplo n.º 16
0
class Slider(Component):
    """ A horizontal or vertical slider bar """

    #------------------------------------------------------------------------
    # Model traits
    #------------------------------------------------------------------------

    min = Float()

    max = Float()

    value = Float()

    # The number of ticks to show on the slider.
    num_ticks = Int(4)

    #------------------------------------------------------------------------
    # Bar and endcap appearance
    #------------------------------------------------------------------------

    # Whether this is a horizontal or vertical slider
    orientation = Enum("h", "v")

    # The thickness, in pixels, of the lines used to render the ticks,
    # endcaps, and main slider bar.
    bar_width = Int(4)

    bar_color = ColorTrait("black")

    # Whether or not to render endcaps on the slider bar
    endcaps = Bool(True)

    # The extent of the endcaps, in pixels.  This is a read-only property,
    # since the endcap size can be set as either a fixed number of pixels or
    # a percentage of the widget's size in the transverse direction.
    endcap_size = Property

    # The extent of the tickmarks, in pixels.  This is a read-only property,
    # since the tick size can be set as either a fixed number of pixels or
    # a percentage of the widget's size in the transverse direction.
    tick_size = Property

    #------------------------------------------------------------------------
    # Slider appearance
    #------------------------------------------------------------------------

    # The kind of marker to use for the slider.
    slider = SliderMarkerTrait("rect")

    # If the slider marker is "rect", this is the thickness of the slider,
    # i.e. its extent in the dimension parallel to the long axis of the widget.
    # For other slider markers, this has no effect.
    slider_thickness = Int(9)

    # The size of the slider, in pixels.  This is a read-only property, since
    # the slider size can be set as either a fixed number of pixels or a
    # percentage of the widget's size in the transverse direction.
    slider_size = Property

    # For slider markers with a filled area, this is the color of the filled
    # area.  For slider markers that are just lines/strokes (e.g. cross, plus),
    # this is the color of the stroke.
    slider_color = ColorTrait("red")

    # For slider markers with a filled area, this is the color of the outline
    # border drawn around the filled area.  For slider markers that have just
    # lines/strokes, this has no effect.
    slider_border = ColorTrait("none")

    # For slider markers with a filled area, this is the width, in pixels,
    # of the outline around the area.  For slider markers that are just lines/
    # strokes, this is the thickness of the stroke.
    slider_outline_width = Int(1)

    # The kiva.CompiledPath representing the custom path to render for the
    # slider, if the **slider** trait is set to "custom".
    custom_slider = Any()

    #------------------------------------------------------------------------
    # Interaction traits
    #------------------------------------------------------------------------

    # Can this slider be interacted with, or is it just a display
    interactive = Bool(True)

    mouse_button = Enum("left", "right")

    event_state = Enum("normal", "dragging")

    #------------------------------------------------------------------------
    # Private traits
    #------------------------------------------------------------------------

    # Returns the coordinate index (0 or 1) corresponding to our orientation.
    # Used internally; read-only property.
    axis_ndx = Property()

    _slider_size_mode = Enum("fixed", "percent")
    _slider_percent = Float(0.0)
    _cached_slider_size = Int(10)

    _endcap_size_mode = Enum("fixed", "percent")
    _endcap_percent = Float(0.0)
    _cached_endcap_size = Int(20)

    _tick_size_mode = Enum("fixed", "percent")
    _tick_size_percent = Float(0.0)
    _cached_tick_size = Int(20)

    # A tuple of (dx, dy) of the difference between the mouse position and
    # center of the slider.
    _offset = Any((0, 0))

    def set_range(self, min, max):
        """ Set the lower (``min``) and upper (``max``) bounds of the
        slider's value range.
        """
        self.min, self.max = min, max

    def map_screen(self, val):
        """ Returns an (x,y) coordinate corresponding to the location of
        **val** on the slider.

        Mapping rules:
        - values at or below ``min`` map to the low end of the bar;
        - values at or above ``max`` map to the high end;
        - if ``min == max``, the result is the bar's midpoint.

        The coordinate across the bar is always the bar's centre line.  The
        result is a two-element list.
        """
        along = self.axis_ndx
        across = 1 - along
        low = self.position[along]
        high = low + self.bounds[along]

        point = [0, 0]
        point[across] = self.position[across] + self.bounds[across] / 2

        if val <= self.min:
            point[along] = low
        elif val >= self.max:
            point[along] = high
        elif self.min == self.max:
            point[along] = (low + high) / 2
        else:
            fraction = (val - self.min) / (self.max - self.min)
            point[along] = fraction * self.bounds[along] + low
        return point

    def map_data(self, x, y, clip=True):
        """ Returns the slider value that corresponds to the given x and y
        screen coordinates.

        Parameters
        ==========
        x, y : Float
            The screen coordinates to map.  Only the coordinate along the
            slider's orientation (x for "h", y for "v") is used.
        clip : Bool (default=True)
            Whether points outside the range should be clipped to the max
            or min value of the slider (depending on which it's closer to).
            If False, the value is extrapolated linearly beyond min/max.

        Returns
        =======
        value : Float
        """
        axis_ndx = self.axis_ndx
        screen_low = self.position[axis_ndx]
        extent = self.bounds[axis_ndx]
        screen_high = screen_low + extent
        coord = x if self.orientation == "h" else y

        if clip:
            if coord >= screen_high:
                return self.max
            elif coord <= screen_low:
                return self.min

        if screen_high == screen_low:
            # zero-length bar: no position can be resolved, use the midpoint
            return (self.max + self.min) / 2

        # float() avoids integer division if screen coordinates are ints
        # under Python 2
        return float(coord - screen_low) / extent * \
            (self.max - self.min) + self.min

    def set_slider_pixels(self, pixels):
        """ Use a fixed slider size instead of a percentage of the widget.

        Parameters
        ==========
        pixels : int
            The width of the slider marker, in pixels
        """
        self._slider_size_mode = "fixed"
        self._cached_slider_size = pixels

    def set_slider_percent(self, percent):
        """ Size the slider marker as a fraction of the slider widget and
        recompute the cached sizes.

        Parameters
        ==========
        percent : float
            The fraction, between 0.0 and 1.0
        """
        self._slider_size_mode = "percent"
        self._slider_percent = percent
        self._update_sizes()

    def set_endcap_pixels(self, pixels):
        """ Use a fixed endcap size instead of a percentage of the widget.

        Parameters
        ==========
        pixels : int
            The width of the endcaps, in pixels
        """
        self._endcap_size_mode = "fixed"
        self._cached_endcap_size = pixels

    def set_endcap_percent(self, percent):
        """ Sets the length of the endcaps to be a percentage of the
        thickness of the slider widget (its extent across the bar).

        The cached pixel sizes are recomputed immediately via
        ``_update_sizes``.

        Parameters
        ==========
        percent : float
            The percentage, between 0.0 and 1.0
        """
        self._endcap_size_mode = "percent"
        self._endcap_percent = percent
        self._update_sizes()

    def set_tick_pixels(self, pixels):
        """ Sets the length of the tick marks to be a fixed number of pixels

        Parameters
        ==========
        pixels : int
            The number of pixels long that each tick mark should be
        """
        self._tick_size_mode = "fixed"
        self._cached_tick_size = pixels

    def set_tick_percent(self, percent):
        """ Sets the length of the tick marks to be a percentage of the
        thickness of the slider widget (its extent across the bar).

        The cached pixel sizes are recomputed immediately via
        ``_update_sizes``.

        Parameters
        ==========
        percent : float
            The percentage, between 0.0 and 1.0
        """
        self._tick_size_mode = "percent"
        self._tick_percent = percent
        self._update_sizes()

    #------------------------------------------------------------------------
    # Rendering methods
    #------------------------------------------------------------------------

    def _draw_mainlayer(self, gc, view_bounds=None, mode="normal"):
        """ Draws the bar, its optional endcaps and tick marks, and the slider.

        Parameters
        ----------
        gc : graphics context
            The graphics context to draw into.
        view_bounds, mode :
            Not used by this method.
        """
        # Centre line of the bar across its thickness.
        bar_x = self.x + self.width / 2
        bar_y = self.y + self.height / 2

        # Draw the bar and endcaps
        gc.set_stroke_color(self.bar_color_)
        gc.set_line_width(self.bar_width)
        if self.orientation == "h":
            gc.move_to(self.x, bar_y)
            gc.line_to(self.x2, bar_y)
            gc.stroke_path()
            if self.endcaps:
                start_y = bar_y - self._cached_endcap_size / 2
                end_y = bar_y + self._cached_endcap_size / 2
                gc.move_to(self.x, start_y)
                gc.line_to(self.x, end_y)
                gc.move_to(self.x2, start_y)
                gc.line_to(self.x2, end_y)
            if self.num_ticks > 0:
                # num_ticks interior ticks plus one at each end of the bar,
                # truncated to integer pixel positions.
                x_pts = linspace(self.x, self.x2,
                                 self.num_ticks + 2).astype(int)
                starts = zeros((len(x_pts), 2), dtype=int)
                starts[:, 0] = x_pts
                starts[:, 1] = bar_y - self._cached_tick_size / 2
                ends = starts.copy()
                ends[:, 1] = bar_y + self._cached_tick_size / 2
                gc.line_set(starts, ends)
        else:
            gc.move_to(bar_x, self.y)
            gc.line_to(bar_x, self.y2)
            if self.endcaps:
                start_x = bar_x - self._cached_endcap_size / 2
                end_x = bar_x + self._cached_endcap_size / 2
                gc.move_to(start_x, self.y)
                gc.line_to(end_x, self.y)
                gc.move_to(start_x, self.y2)
                gc.line_to(end_x, self.y2)
            if self.num_ticks > 0:
                y_pts = linspace(self.y, self.y2,
                                 self.num_ticks + 2).astype(int)
                starts = zeros((len(y_pts), 2), dtype=int)
                starts[:, 1] = y_pts
                starts[:, 0] = bar_x - self._cached_tick_size / 2
                ends = starts.copy()
                ends[:, 0] = bar_x + self._cached_tick_size / 2
                gc.line_set(starts, ends)
        # Strokes whatever segments are still pending on the path (the
        # endcaps, and the bar itself in the vertical case).
        gc.stroke_path()

        # Draw the slider
        if self.slider == "rect":
            gc.set_fill_color(self.slider_color_)
            gc.set_stroke_color(self.slider_border_)
            gc.set_line_width(self.slider_outline_width)
            rect = self._get_rect_slider_bounds()
            gc.rect(*rect)
            gc.draw_path()
        else:
            # Only marker sliders need the value's screen position here; the
            # rect bounds helper computes its own.
            pt = self.map_screen(self.value)
            self._render_marker(gc, pt, self._cached_slider_size,
                                self.slider_(), self.custom_slider)

    def _get_rect_slider_bounds(self):
        """ Returns the (x, y, w, h) bounds of the rectangle representing the slider.
        Used for rendering and hit detection.
        """
        bar_x = self.x + self.width / 2
        bar_y = self.y + self.height / 2
        pt = self.map_screen(self.value)
        if self.orientation == "h":
            slider_height = self._cached_slider_size
            return (pt[0] - self.slider_thickness, bar_y - slider_height / 2,
                    self.slider_thickness, slider_height)
        else:
            slider_width = self._cached_slider_size
            return (bar_x - slider_width / 2, pt[1] - self.slider_thickness,
                    slider_width, self.slider_thickness)

    def _render_marker(self, gc, point, size, marker, custom_path):
        """ Draws a marker-style slider at *point*.

        Three drawing strategies are tried in order, depending on what *gc*
        supports:

        1. ``gc.draw_marker_at_points`` (built-in markers only).  A return
           value of 0 is treated as "not handled" and the next strategy is
           tried.  Note that this path does not apply ``marker.antialias``.
        2. ``gc.draw_path_at_points`` with a path built from the marker, or
           with *custom_path* for a CustomMarker.
        3. Translating the CTM to *point* and drawing the path directly.

        Custom markers are always drawn in STROKE mode.  Everything runs
        inside ``with gc:`` (presumably saving/restoring graphics state as
        Kiva contexts do — confirm for other gc types).

        Parameters
        ----------
        gc : graphics context
            Context to draw into.
        point : (x, y)
            Screen position at which to draw the marker.
        size : number
            Marker size passed to the marker drawing calls.
        marker :
            Marker object; this method reads ``draw_mode``, ``antialias``,
            ``kiva_marker`` and calls ``add_to_path``.
        custom_path :
            Path drawn when *marker* is a CustomMarker.
        """
        with gc:
            gc.begin_path()
            # Stroke-only markers are outlined in the slider colour with
            # slider_thickness; other markers are filled with the slider
            # colour and outlined with slider_border_.
            if marker.draw_mode == STROKE:
                gc.set_stroke_color(self.slider_color_)
                gc.set_line_width(self.slider_thickness)
            else:
                gc.set_fill_color(self.slider_color_)
                gc.set_stroke_color(self.slider_border_)
                gc.set_line_width(self.slider_outline_width)

            # Strategy 1: the draw call inside the condition both draws and
            # reports success; a result of 0 falls through to the elif.
            if hasattr(gc, "draw_marker_at_points") and \
                    (marker.__class__ != CustomMarker) and \
                    (gc.draw_marker_at_points([point], size, marker.kiva_marker) != 0):
                pass
            elif hasattr(gc, "draw_path_at_points"):
                # Strategy 2: build (or reuse) a path and stamp it at point.
                if marker.__class__ != CustomMarker:
                    path = gc.get_empty_path()
                    marker.add_to_path(path, size)
                    mode = marker.draw_mode
                else:
                    path = custom_path
                    mode = STROKE
                if not marker.antialias:
                    gc.set_antialias(False)
                gc.draw_path_at_points([point], path, mode)
            else:
                # Strategy 3: draw into the gc's own path at a translated
                # origin.
                if not marker.antialias:
                    gc.set_antialias(False)
                if marker.__class__ != CustomMarker:
                    gc.translate_ctm(*point)
                    # Kiva GCs have a path-drawing interface
                    marker.add_to_path(gc, size)
                    gc.draw_path(marker.draw_mode)
                else:
                    path = custom_path
                    gc.translate_ctm(*point)
                    gc.add_path(path)
                    gc.draw_path(STROKE)

    #------------------------------------------------------------------------
    # Interaction event handlers
    #------------------------------------------------------------------------

    def normal_left_down(self, event):
        """ Starts a press if the left button is the slider's active button. """
        if self.mouse_button != "left":
            return None
        return self._mouse_pressed(event)

    def dragging_left_up(self, event):
        """ Ends a drag on release if the left button is the active button. """
        if self.mouse_button != "left":
            return None
        return self._mouse_released(event)

    def normal_right_down(self, event):
        """ Starts a press if the right button is the slider's active button. """
        if self.mouse_button != "right":
            return None
        return self._mouse_pressed(event)

    def dragging_right_up(self, event):
        """ Ends a drag on release if the right button is the active button. """
        if self.mouse_button != "right":
            return None
        return self._mouse_released(event)

    def dragging_mouse_move(self, event):
        """ Moves the slider to follow the pointer during a drag.

        The grab offset recorded at press time is subtracted so the slider
        keeps its position relative to the pointer.
        """
        offset_x, offset_y = self._offset
        pointer_x = event.x - offset_x
        pointer_y = event.y - offset_y
        self.value = self.map_data(pointer_x, pointer_y)
        event.handled = True
        self.request_redraw()

    def dragging_mouse_leave(self, event):
        """ Abandons the drag when the pointer leaves the widget. """
        next_state = "normal"
        self.event_state = next_state

    def _mouse_pressed(self, event):
        # Determine the slider bounds so we can hit test it
        pt = self.map_screen(self.value)
        if self.slider == "rect":
            x, y, w, h = self._get_rect_slider_bounds()
            x2 = x + w
            y2 = y + h
        else:
            x, y = pt
            size = self._cached_slider_size
            x -= size / 2
            y -= size / 2
            x2 = x + size
            y2 = y + size

        # Hit test both the slider and against the bar.  If the user has
        # clicked on the bar but outside of the slider, we set the _offset
        # and call dragging_mouse_move() to teleport the slider to the
        # mouse click position.
        if self.orientation == "v" and (x <= event.x <= x2):
            if not (y <= event.y <= y2):
                self._offset = (event.x - pt[0], 0)
                self.dragging_mouse_move(event)
            else:
                self._offset = (event.x - pt[0], event.y - pt[1])
        elif self.orientation == "h" and (y <= event.y <= y2):
            if not (x <= event.x <= x2):
                self._offset = (0, event.y - pt[1])
                self.dragging_mouse_move(event)
            else:
                self._offset = (event.x - pt[0], event.y - pt[1])
        else:
            # The mouse click missed the bar and the slider.
            return

        event.handled = True
        self.event_state = "dragging"
        return

    def _mouse_released(self, event):
        self.event_state = "normal"
        event.handled = True

    #------------------------------------------------------------------------
    # Private trait event handlers and property getters/setters
    #------------------------------------------------------------------------

    def _get_axis_ndx(self):
        if self.orientation == "h":
            return 0
        else:
            return 1

    def _get_slider_size(self):
        return self._cached_slider_size

    def _get_endcap_size(self):
        return self._cached_endcap_size

    def _get_tick_size(self):
        return self._cached_tick_size

    @on_trait_change("bounds,bounds_items")
    def _update_sizes(self):
        """ Recomputes the cached slider, endcap and tick-mark sizes.

        Every element whose size mode is "percent" gets a cached pixel size of
        ``int(thickness * fraction)``, where thickness is the widget's extent
        across the bar: its height when the orientation is "h", otherwise its
        width.  Elements in "fixed" mode keep their explicitly set sizes.
        Runs whenever ``bounds`` (or its items) change.
        """
        thickness = self.height if self.orientation == "h" else self.width

        if self._slider_size_mode == "percent":
            self._cached_slider_size = int(thickness * self._slider_percent)
        if self._endcap_size_mode == "percent":
            self._cached_endcap_size = int(thickness * self._endcap_percent)
        # Tick marks were previously never recomputed here, so
        # set_tick_percent() had no effect.  getattr() guards against the
        # mode never having been set (its trait declaration is not visible
        # here — NOTE(review): confirm and drop the default if declared).
        if getattr(self, "_tick_size_mode", "fixed") == "percent":
            self._cached_tick_size = int(thickness * self._tick_percent)
Ejemplo n.º 17
0
class Loop(HasTraits):
    """ A current loop class.

    A circular current loop described by the centre ``position``, the
    ``direction`` of its axis and its ``radius``.  ``Bnorm`` is the magnitude
    of the loop's magnetic field sampled on the global grid arrays ``X``,
    ``Y`` and ``Z``; evaluating it also stores the field components in
    ``self.Bx``, ``self.By`` and ``self.Bz``.
    """

    #-------------------------------------------------------------------------
    # Public traits
    #-------------------------------------------------------------------------
    direction = Array(float,
                      value=(0, 0, 1),
                      cols=3,
                      shape=(3, ),
                      desc='directing vector of the loop',
                      enter_set=True,
                      auto_set=False)

    radius = CFloat(0.1,
                    desc='radius of the loop',
                    enter_set=True,
                    auto_set=False)

    position = Array(float,
                     value=(0, 0, 0),
                     cols=3,
                     shape=(3, ),
                     desc='position of the center of the loop',
                     enter_set=True,
                     auto_set=False)

    # mlab plot of the coil; created by the first call to display().
    _plot = None

    Bnorm = Property(depends_on='direction,position,radius')

    view = View('position', 'direction', 'radius', '_')

    #-------------------------------------------------------------------------
    # Loop interface
    #-------------------------------------------------------------------------
    def base_vectors(self):
        """ Returns 3 orthonormal base vectors, the first one colinear to
            the axis of the loop.

            Returns
            -------
            n, l, m : ndarray, shape (3,)
                ``n`` is the unit axis vector; ``l`` and ``m`` are unit
                vectors spanning the plane of the loop, with m = n x l.

            A zero ``direction`` has no axis and yields NaNs.
        """
        # Normalize n.  Dividing by the squared norm (as before) only gives
        # a unit vector when |direction| already equals 1.
        n = self.direction / np.sqrt((self.direction**2).sum(axis=-1))

        # Choose two vectors perpendicular to n; the choice is arbitrary
        # since the coil is symmetric about n.  Both formulas below are
        # exactly orthogonal to n, but (0, n_z, -n_y) vanishes when n lies
        # along the x axis, so the other one is used there.
        if np.allclose(n[1:], 0):
            l = np.r_[n[2], 0, -n[0]]
        else:
            l = np.r_[0, n[2], -n[1]]

        l = l / np.sqrt((l**2).sum(axis=-1))
        m = np.cross(n, l)
        return n, l, m

    @on_trait_change('Bnorm')
    def redraw(self):
        """ Refreshes the coil and the field visualization.

        Does nothing until an ``app`` attribute is attached to this object
        and its scene has a renderer.
        """
        if hasattr(self, 'app') and self.app.scene._renderer is not None:
            self.display()
            self.app.visualize_field()

    def display(self):
        """
        Display the coil in the 3D view.

        The coil is drawn as a 30-point tube around the loop; the first call
        creates the plot, later calls update its coordinates in place.
        """
        n, l, m = self.base_vectors()
        theta = np.linspace(0, 2 * np.pi, 30)[..., np.newaxis]
        coil = self.radius * (np.sin(theta) * l + np.cos(theta) * m)
        coil += self.position
        coil_x, coil_y, coil_z = coil.T
        if self._plot is None:
            self._plot = self.app.scene.mlab.plot3d(coil_x,
                                                    coil_y,
                                                    coil_z,
                                                    tube_radius=0.007,
                                                    color=(0, 0, 1),
                                                    name='Coil')
        else:
            self._plot.mlab_source.trait_set(x=coil_x, y=coil_y, z=coil_z)

    def _get_Bnorm(self):
        """
        returns the magnetic field for the current loop calculated
        from eqns (1) and (2) in Phys Rev A Vol. 35, N 4, pp. 1535-1546; 1987.

        The field is evaluated at every point of the global grid ``X``,
        ``Y``, ``Z``.  The Cartesian components are stored in ``self.Bx``,
        ``self.By`` and ``self.Bz`` and the magnitude is returned; all four
        arrays have the shape of ``X``.  Points whose magnitude exceeds ten
        times the median magnitude (i.e. very close to the wire) are set to
        NaN.
        """
        ### Translate the coordinates in the coil's frame
        n, l, m = self.base_vectors()
        R = self.radius

        # Rows are the coil-frame axes in lab coordinates, so
        # ``v_coil @ trans`` maps a coil-frame vector to the lab frame.  The
        # rows are orthonormal, hence the inverse map is ``trans.T``.
        trans = np.vstack((l, m, n))

        # Grid points relative to the centre of the coil, in the coil frame.
        r = np.c_[np.ravel(X), np.ravel(Y), np.ravel(Z)] - self.position
        r = np.dot(r, trans.T)

        #### calculate field

        # express the coordinates in cylindrical form
        x = r[:, 0]
        y = r[:, 1]
        z = r[:, 2]
        rho = np.sqrt(x**2 + y**2)
        # Azimuth measured from the l axis towards m.  arctan2 takes (y, x);
        # the previous arctan2(x, y) mirrored the radial field components
        # about the x = y line (the magnitude was unaffected).
        theta = np.arctan2(y, x)

        outer = (R + rho)**2 + z**2
        inner = (R - rho)**2 + z**2
        k2 = (4 * R * rho) / outer
        E = special.ellipe(k2)
        K = special.ellipk(k2)
        # Division by zero is expected on the axis and on the wire; those
        # points are fixed up or thresholded below.
        with np.errstate(divide='ignore', invalid='ignore'):
            Bz = 1 / np.sqrt(outer) * (K + E * (R**2 - rho**2 - z**2) / inner)
            Brho = z / (rho * np.sqrt(outer)) * (
                -K + E * (R**2 + rho**2 + z**2) / inner)
        # On the axis of the coil we get a division by zero here. This returns
        # a NaN, where the field is actually zero:
        Brho[np.isnan(Brho)] = 0

        B = np.c_[np.cos(theta) * Brho, np.sin(theta) * Brho, Bz]

        # Rotate the field back in the lab's frame
        B = np.dot(B, trans)

        Bx, By, Bz = (np.reshape(component, X.shape) for component in B.T)

        Bnorm = np.sqrt(Bx**2 + By**2 + Bz**2)

        # We need to threshold ourselves, rather than with VTK, to be able
        # to use an ImageData.  nanmedian keeps a single NaN sample (e.g. a
        # grid point exactly on the wire) from disabling the threshold.
        Bmax = 10 * np.nanmedian(Bnorm)
        too_large = Bnorm > Bmax

        # np.NAN was removed in NumPy 2.0; np.nan works in every version.
        Bx[too_large] = np.nan
        By[too_large] = np.nan
        Bz[too_large] = np.nan
        Bnorm[too_large] = np.nan

        self.Bx = Bx
        self.By = By
        self.Bz = Bz
        return Bnorm
Ejemplo n.º 18
0
class ThermoSource(BaseSource):
    """ Source device driven through text commands sent with ``self.ask``.

    Readback values are cached in private traits (``_trap_voltage``,
    ``_y_symmetry``, ...).  The public properties return the cached values;
    their setters send the corresponding set command and update the cache
    only when the device replies "ok".
    """

    trap_voltage = Property(depends_on='_trap_voltage')
    _trap_voltage = Float
    trap_current = Property(depends_on='_trap_current')
    _trap_current = Float

    z_symmetry = Property(depends_on='_z_symmetry')
    y_symmetry = Property(depends_on='_y_symmetry')
    extraction_lens = Property(Range(0, 100.0), depends_on='_extraction_lens')
    emission = Float

    _y_symmetry = Float  # Range(0.0, 100.)
    _z_symmetry = Float  # Range(0.0, 100.)

    # Slider limits used by traits_view() for the symmetry editors.
    y_symmetry_low = Float(-100.0)
    y_symmetry_high = Float(100.0)
    z_symmetry_low = Float(-100.0)
    z_symmetry_high = Float(100.0)

    _extraction_lens = Float  # Range(0.0, 100.)

    def set_hv(self, v):
        """ Sends ``SetHV <v>``; returns True if the device replied "ok". """
        return self._set_value('SetHV', v)

    def read_emission(self):
        """ Reads the source current readback into ``emission``. """
        return self._read_value('GetParameter Source Current Readback',
                                'emission')

    def read_trap_current(self):
        """ Reads the trap current readback into ``_trap_current``. """
        return self._read_value('GetParameter Trap Current Readback',
                                '_trap_current')

    def read_y_symmetry(self):
        """ Reads the Y symmetry into ``_y_symmetry``. """
        return self._read_value('GetYSymmetry', '_y_symmetry')

    def read_z_symmetry(self):
        """ Reads the Z symmetry into ``_z_symmetry``. """
        return self._read_value('GetZSymmetry', '_z_symmetry')

    def read_trap_voltage(self):
        """ Reads the trap voltage readback into ``_trap_voltage``. """
        return self._read_value('GetParameter Trap Voltage Readback',
                                '_trap_voltage')

    def read_hv(self):
        """ Reads the high voltage into ``current_hv``.

        NOTE(review): ``current_hv`` is not declared here; presumably it is
        a trait of BaseSource — confirm.
        """
        return self._read_value('GetHighVoltage', 'current_hv')

    def _set_value(self, name, v):
        """ Sends ``"<name> <v>"`` and reports whether the device accepted it.

        Returns True when the reply is "ok" (case- and whitespace-
        insensitive); otherwise returns None.
        """
        r = self.ask('{} {}'.format(name, v))
        if r is not None:
            if r.lower().strip() == 'ok':
                return True

    def _read_value(self, name, value):
        """ Queries *name*, stores the numeric reply in attribute *value*.

        The reply is rounded to 3 decimal places before being stored, and
        the stored value is returned.  Returns None, leaving the attribute
        untouched, if the reply is missing or not numeric.
        """
        r = self.ask(name)
        try:
            r = float('{:0.3f}'.format(float(r)))
            setattr(self, value, r)
            return getattr(self, value)
        except (ValueError, TypeError):
            # Best effort: a missing or garbled reply leaves the cache as is.
            pass

    def sync_parameters(self):
        """ Refreshes the cached symmetry, trap current and HV readbacks. """
        self.read_y_symmetry()
        self.read_z_symmetry()
        self.read_trap_current()
        self.read_hv()

    def traits_view(self):
        """ Builds the default view with slider editors for the symmetries. """
        v = View(
            Item('nominal_hv', format_str='%0.4f'),
            Item('current_hv', format_str='%0.4f', style='readonly'),
            Item('trap_current'), Item('trap_voltage'),
            Item('y_symmetry',
                 editor=RangeEditor(low_name='y_symmetry_low',
                                    high_name='y_symmetry_high',
                                    mode='slider')),
            Item('z_symmetry',
                 editor=RangeEditor(low_name='z_symmetry_low',
                                    high_name='z_symmetry_high',
                                    mode='slider')), Item('extraction_lens'))
        return v

    # ===============================================================================
    # property get/set
    # ===============================================================================
    def _get_trap_voltage(self):
        return self._trap_voltage

    def _get_trap_current(self):
        return self._trap_current

    def _get_y_symmetry(self):
        return self._y_symmetry

    def _get_z_symmetry(self):
        return self._z_symmetry

    def _get_extraction_lens(self):
        return self._extraction_lens

    def _set_trap_voltage(self, v):
        if self._set_value('SetParameter', 'Trap Voltage Set,{}'.format(v)):
            # Previously this wrote the voltage into _trap_current, so the
            # trap_voltage property never changed and the trap current
            # display showed a voltage.
            self._trap_voltage = v

    def _set_trap_current(self, v):
        if self._set_value('SetParameter', 'Trap Current Set,{}'.format(v)):
            self._trap_current = v

    def _set_y_symmetry(self, v):
        if self._set_value('SetYSymmetry', v):
            self._y_symmetry = v

    def _set_z_symmetry(self, v):
        if self._set_value('SetZSymmetry', v):
            self._z_symmetry = v

    def _set_extraction_lens(self, v):
        if self._set_value('SetExtractionLens', v):
            self._extraction_lens = v
Ejemplo n.º 19
0
class SimpleEditor(Editor):
    """ Simple style of editor for sets.

        The editor displays two list boxes, with buttons for moving the selected
        items from left to right, or vice versa. If **can_move_all** on the
        factory is True, then buttons are displayed for moving all the items to
        one box or the other. If the set is ordered, buttons are displayed for
        moving the selected item up or down in right-side list box.

        The left box lists the available (unused) items; the right box lists
        the items in the edited value.
    """

    # -------------------------------------------------------------------------
    #  Trait definitions:
    # -------------------------------------------------------------------------

    #: Current set of enumeration names:
    names = Property()

    #: Current mapping from names to values:
    mapping = Property()

    #: Current inverse mapping from values to names:
    inverse_mapping = Property()

    #: Is set editor scrollable? This value overrides the default.
    scrollable = True

    def init(self, parent):
        """ Finishes initializing the editor by creating the underlying toolkit
            widget.

            The available values come either from the extended trait named by
            ``factory.name`` or, when that is empty, from ``factory.values``;
            a listener is attached to whichever source is used so the
            editor refreshes when it changes.
        """
        factory = self.factory
        if factory.name != "":
            self._object, self._name, self._value = self.parse_extended_name(
                factory.name
            )
            self.values_changed()
            self._object.on_trait_change(
                self._values_changed, self._name, dispatch="ui"
            )
        else:
            self._value = lambda: self.factory.values
            self.values_changed()
            factory.on_trait_change(
                self._values_changed, "values", dispatch="ui"
            )

        self.control = panel = TraitsUIPanel(parent, -1)
        hsizer = wx.BoxSizer(wx.HORIZONTAL)
        vsizer = wx.BoxSizer(wx.VERTICAL)

        # Left box: available items.  Double-click moves an item right.
        self._unused = self._create_listbox(
            panel,
            hsizer,
            self._on_unused,
            self._on_use,
            factory.left_column_title,
        )

        # Middle column of transfer (and optional reorder) buttons.
        self._use_all = self._unuse_all = self._up = self._down = None
        if factory.can_move_all:
            self._use_all = self._create_button(
                ">>", panel, vsizer, 15, self._on_use_all
            )

        self._use = self._create_button(">", panel, vsizer, 15, self._on_use)
        self._unuse = self._create_button(
            "<", panel, vsizer, 0, self._on_unuse
        )
        if factory.can_move_all:
            self._unuse_all = self._create_button(
                "<<", panel, vsizer, 15, self._on_unuse_all
            )

        if factory.ordered:
            self._up = self._create_button(
                "Move Up", panel, vsizer, 30, self._on_up
            )
            self._down = self._create_button(
                "Move Down", panel, vsizer, 0, self._on_down
            )

        hsizer.Add(vsizer, 0, wx.LEFT | wx.RIGHT, 8)

        # Right box: items in the value.  Double-click moves an item left.
        self._used = self._create_listbox(
            panel,
            hsizer,
            self._on_value,
            self._on_unuse,
            factory.right_column_title,
        )

        panel.SetSizer(hsizer)

        # Also refresh on in-place list mutations of the edited trait.
        self.context_object.on_trait_change(
            self.update_editor, self.extended_name + "_items?", dispatch="ui"
        )
        self.set_tooltip()

    def _get_names(self):
        """ Gets the current set of enumeration names.
        """
        return self._names

    def _get_mapping(self):
        """ Gets the current mapping.
        """
        return self._mapping

    def _get_inverse_mapping(self):
        """ Gets the current inverse mapping.
        """
        return self._inverse_mapping

    def _create_listbox(self, parent, sizer, handler1, handler2, title):
        """ Creates a titled, multi-selection list box column.

            *handler1* is bound to selection changes and *handler2* to
            double-clicks.  Returns the new wx.ListBox.
        """
        column_sizer = wx.BoxSizer(wx.VERTICAL)

        # Add the column title in emphasized text:
        title_widget = wx.StaticText(parent, -1, title)
        font = title_widget.GetFont()
        emphasis_font = wx.Font(
            font.GetPointSize() + 1, font.GetFamily(), font.GetStyle(), wx.BOLD
        )
        title_widget.SetFont(emphasis_font)
        column_sizer.Add(title_widget, 0, 0)

        # Create the list box and add it to the column (named so as not to
        # shadow the builtin ``list``):
        listbox = wx.ListBox(
            parent, -1, style=wx.LB_EXTENDED | wx.LB_NEEDED_SB
        )
        column_sizer.Add(listbox, 1, wx.EXPAND)

        # Add the column to the SetEditor widget:
        sizer.Add(column_sizer, 1, wx.EXPAND)

        # Hook up the event handlers:
        parent.Bind(wx.EVT_LISTBOX, handler1, id=listbox.GetId())
        parent.Bind(wx.EVT_LISTBOX_DCLICK, handler2, id=listbox.GetId())

        return listbox

    def _create_button(self, label, parent, sizer, space_before, handler):
        """ Creates a button preceded by *space_before* pixels of spacing and
            bound to *handler*.
        """
        button = wx.Button(parent, -1, label, style=wx.BU_EXACTFIT)
        sizer.AddSpacer(space_before)
        sizer.Add(button, 0, wx.EXPAND | wx.BOTTOM, 8)
        parent.Bind(wx.EVT_BUTTON, handler, id=button.GetId())
        return button

    def values_changed(self):
        """ Recomputes the cached data based on the underlying enumeration model
            or the values of the factory.
        """
        self._names, self._mapping, self._inverse_mapping = enum_values_changed(
            self._value(), self.string_value
        )

    def _values_changed(self):
        """ Handles the underlying object model's enumeration set or factory's
            values being changed.
        """
        self.values_changed()
        self.update_editor()

    def update_editor(self):
        """ Updates the editor when the object trait changes externally to the
            editor.

            Both list boxes are rebuilt from the current value and the
            available values, keeping selections (by label) where possible.
        """
        # Check for any items having been deleted from the enumeration that are
        # still present in the object value.  Assigning the pruned value
        # triggers another update, so stop here.
        mapping = self.inverse_mapping.copy()
        values = [v for v in self.value if v in mapping]
        if len(values) < len(self.value):
            self.value = values
            return

        # Get a list of the selected items in the right box:
        used = self._used
        used_labels = self._get_selected_strings(used)

        # Get a list of the selected items in the left box:
        unused = self._unused
        unused_labels = self._get_selected_strings(unused)

        # Empty list boxes in preparation for rebuilding from current values:
        used.Clear()
        unused.Clear()

        # Ensure right list box is kept alphabetized unless insertion
        # order is relevant:
        if not self.factory.ordered:
            values = sorted(values)

        # Rebuild the right listbox, removing each used value from the copy
        # of the inverse mapping so that only unused values remain in it:
        used_selections = []
        for i, value in enumerate(values):
            label = mapping[value]
            used.Append(label)
            del mapping[value]
            if label in used_labels:
                used_selections.append(i)

        # Rebuild the left listbox:
        unused_selections = []
        unused_items = sorted(mapping.values())
        mapping = self.mapping
        self._unused_items = [mapping[ui] for ui in unused_items]
        for i, unused_item in enumerate(unused_items):
            unused.Append(unused_item)
            if unused_item in unused_labels:
                unused_selections.append(i)

        # If nothing is selected, default selection should be top of left box,
        # or of right box if left box is empty:
        if (len(used_selections) == 0) and (len(unused_selections) == 0):
            if unused.GetCount() == 0:
                used_selections.append(0)
            else:
                unused_selections.append(0)

        used_count = used.GetCount()
        for i in used_selections:
            if i < used_count:
                used.SetSelection(i)

        unused_count = unused.GetCount()
        for i in unused_selections:
            if i < unused_count:
                unused.SetSelection(i)

        self._check_up_down()
        self._check_left_right()

    def dispose(self):
        """ Disposes of the contents of an editor, detaching every listener
            attached in init().
        """
        if self._object is not None:
            self._object.on_trait_change(
                self._values_changed, self._name, remove=True
            )
        else:
            self.factory.on_trait_change(
                self._values_changed, "values", remove=True
            )

        self.context_object.on_trait_change(
            self.update_editor, self.extended_name + "_items?", remove=True
        )

        super(SimpleEditor, self).dispose()

    def get_error_control(self):
        """ Returns the editor's control for indicating error status.
        """
        return [self._unused, self._used]

    # -------------------------------------------------------------------------
    #  Event handlers:
    # -------------------------------------------------------------------------

    def _on_value(self, event):
        """ Selection changed in the right box. """
        if not self.factory.ordered:
            self._clear_selection(self._unused)
        self._check_left_right()
        self._check_up_down()

    def _on_unused(self, event):
        """ Selection changed in the left box. """
        if not self.factory.ordered:
            self._clear_selection(self._used)
        self._check_left_right()
        self._check_up_down()

    def _on_use(self, event):
        """ Moves the selected left-box items into the value. """
        self._unused_items, self.value = self._transfer_items(
            self._unused, self._used, self._unused_items, self.value
        )

    def _on_unuse(self, event):
        """ Moves the selected right-box items out of the value. """
        self.value, self._unused_items = self._transfer_items(
            self._used, self._unused, self.value, self._unused_items
        )

    def _on_use_all(self, event):
        """ Moves every left-box item into the value. """
        self._unused_items, self.value = self._transfer_all(
            self._unused, self._used, self._unused_items, self.value
        )

    def _on_unuse_all(self, event):
        """ Moves every right-box item out of the value. """
        self.value, self._unused_items = self._transfer_all(
            self._used, self._unused, self.value, self._unused_items
        )

    def _on_up(self, event):
        self._move_item(-1)

    def _on_down(self, event):
        self._move_item(1)

    # -------------------------------------------------------------------------
    #  Private methods:
    # -------------------------------------------------------------------------

    def _clear_selection(self, box):
        """ Unselects all items in the given ListBox
        """
        for i in box.GetSelections():
            box.Deselect(i)

    def _transfer_all(self, list_from, list_to, values_from, values_to):
        """ Transfers all items from one list to another, appending them in
            order to the end of *list_to*.

            Returns updated copies ``(values_from, values_to)``; the input
            lists are not modified.
        """
        values_from = values_from[:]
        values_to = values_to[:]

        self._clear_selection(list_from)
        while list_from.GetCount() > 0:
            index_to = list_to.GetCount()
            list_from.SetSelection(0)
            list_to.InsertItems(
                self._get_selected_strings(list_from), index_to
            )
            list_from.Delete(0)
            values_to.append(values_from[0])
            del values_from[0]

        list_to.SetSelection(0)
        self._check_left_right()
        self._check_up_down()

        return (values_from, values_to)

    def _transfer_items(self, list_from, list_to, values_from, values_to):
        """ Transfers the selected items from one list box to the other.

            The items are inserted in front of the first selected item of
            *list_to*, or at its top if nothing is selected there.

            Returns updated copies ``(values_from, values_to)``; the input
            lists are not modified.
        """
        values_from = values_from[:]
        values_to = values_to[:]
        indices_from = list_from.GetSelections()
        index_from = max(self._get_first_selection(list_from), 0)
        index_to = max(self._get_first_selection(list_to), 0)

        self._clear_selection(list_to)

        # Get the list of strings in the "from" box to be moved, in display
        # order, and insert them into the "to" box in that same order:
        selected_list = self._get_selected_strings(list_from)
        list_to.InsertItems(selected_list, index_to)

        # Delete the transferred items from the "from" box, last first so the
        # remaining indices stay valid:
        for index in reversed(indices_from):
            list_from.Delete(index)

        # Move the corresponding values.  Each one is inserted at the same
        # position, so walking the labels in reverse leaves them in display
        # order within values_to (matching the list box above; previously the
        # list box received the reversed order instead).
        for item_label in reversed(selected_list):
            val_index_from = values_from.index(self.mapping[item_label])
            values_to.insert(index_to, values_from[val_index_from])
            del values_from[val_index_from]

            # If right list is ordered, keep moved items selected:
            if self.factory.ordered:
                list_to.SetSelection(list_to.FindString(item_label))

        # Reset the selection in the "from" box, clamped to its new length:
        count = list_from.GetCount()
        if count > 0:
            if index_from >= count:
                index_from -= 1
            list_from.SetSelection(index_from)

        self._check_left_right()
        self._check_up_down()

        return (values_from, values_to)

    def _move_item(self, direction):
        """ Moves the selected item of the "used" list one step up
            (*direction* -1) or down (*direction* +1), in both the list box
            and the editor's value.
        """
        # Move the item up/down within the list:
        listbox = self._used
        index_from = self._get_first_selection(listbox)
        index_to = index_from + direction
        label = listbox.GetString(index_from)
        listbox.Deselect(index_from)
        listbox.Delete(index_from)
        listbox.Insert(label, index_to)
        listbox.SetSelection(index_to)

        # Enable the up/down buttons appropriately:
        self._check_up_down()

        # Move the item up/down within the editor's trait value by swapping
        # the adjacent pair starting at ``index``:
        value = self.value
        if direction < 0:
            index = index_to
            values = [value[index_from], value[index_to]]
        else:
            index = index_from
            values = [value[index_to], value[index_from]]
        self.value = value[:index] + values + value[index + 2 :]

    def _check_up_down(self):
        """ Sets the proper enabled state for the up and down buttons: each
            is enabled only when exactly one item is selected and it can
            move in that direction.
        """
        if self.factory.ordered:
            index_selected = self._used.GetSelections()
            self._up.Enable(
                (len(index_selected) == 1) and (index_selected[0] > 0)
            )
            self._down.Enable(
                (len(index_selected) == 1)
                and (index_selected[0] < (self._used.GetCount() - 1))
            )

    def _check_left_right(self):
        """ Sets the proper enabled state for the left and right buttons.

            A transfer button (single or "all") is enabled when its source
            box is non-empty and has a selection.
        """
        can_use = (
            self._unused.GetCount() > 0
            and self._get_first_selection(self._unused) >= 0
        )
        can_unuse = (
            self._used.GetCount() > 0
            and self._get_first_selection(self._used) >= 0
        )
        self._use.Enable(can_use)
        self._unuse.Enable(can_unuse)

        if self.factory.can_move_all:
            self._use_all.Enable(can_use)
            self._unuse_all.Enable(can_unuse)

    def _get_selected_strings(self, listbox):
        """ Returns a list of the selected strings in the given *listbox*.
        """
        return [listbox.GetString(index) for index in listbox.GetSelections()]

    def _get_first_selection(self, listbox):
        """ Returns the index of the first (or only) selected item, or -1 if
            nothing is selected.
        """
        select_list = listbox.GetSelections()
        if not select_list:
            return -1

        return select_list[0]
Ejemplo n.º 20
0
class BaseEditor(Editor):
    """ Common base class for the enumeration editors.

    Caches the enumeration's names, its name -> value mapping and its
    value -> name mapping, and refreshes that cache whenever the source of
    the enumeration values changes.  The source is either an extended trait
    name on the edited object (when ``factory.name`` is non-empty) or the
    factory's own **values** trait.
    """

    #: Current set of enumeration names:
    names = Property()

    #: Current mapping from names to values:
    mapping = Property()

    #: Current inverse mapping from values to names:
    inverse_mapping = Property()

    # -------------------------------------------------------------------------
    #  BaseEditor Interface
    # -------------------------------------------------------------------------

    def values_changed(self):
        """ Rebuilds the cached names and mappings.

            The current enumeration values are obtained by calling
            ``self._value()`` and converted with ``enum_values_changed``
            using this editor's ``string_value``.
        """
        cached = enum_values_changed(self._value(), self.string_value)
        self._names, self._mapping, self._inverse_mapping = cached

    def rebuild_editor(self):
        """ Rebuilds the editor's contents after the enumeration values change.

            Toolkit-specific subclasses are expected to override this; the
            base implementation always raises NotImplementedError.  (Per the
            original documentation, this is not needed for the Qt backends.)
        """
        raise NotImplementedError

    # -------------------------------------------------------------------------
    #  Editor Interface
    # -------------------------------------------------------------------------

    def init(self, parent):
        """ Finishes initializing the editor by creating the underlying toolkit
            widget.

            Sets ``self._value`` to a zero-argument callable returning the
            current enumeration values, fills the cache, and registers
            ``_values_changed`` (with ``dispatch="ui"``) as a listener on
            whichever trait supplies those values.
        """
        factory = self.factory
        if factory.name == "":
            # Values come straight from the factory's 'values' trait.
            self._value = lambda: self.factory.values
            self.values_changed()
            factory.on_trait_change(
                self._values_changed, "values", dispatch="ui"
            )
            return

        # Values come from an extended trait name on the edited object.
        self._object, self._name, self._value = self.parse_extended_name(
            factory.name
        )
        self.values_changed()
        self._object.on_trait_change(
            self._values_changed, " " + self._name, dispatch="ui"
        )

    def dispose(self):
        """ Disposes of the contents of an editor.

            Unregisters the listener that ``init`` installed, then defers to
            the base class.
        """
        if self._object is not None:
            source, trait_name = self._object, " " + self._name
        else:
            source, trait_name = self.factory, "values"
        source.on_trait_change(self._values_changed, trait_name, remove=True)

        super(BaseEditor, self).dispose()

    # -------------------------------------------------------------------------
    #  Private interface
    # -------------------------------------------------------------------------

    # Property getters --------------------------------------------------------

    def _get_names(self):
        """ Returns the cached enumeration names.
        """
        return self._names

    def _get_mapping(self):
        """ Returns the cached name -> value mapping.
        """
        return self._mapping

    def _get_inverse_mapping(self):
        """ Returns the cached value -> name mapping.
        """
        return self._inverse_mapping

    # Trait change handlers --------------------------------------------------

    def _values_changed(self):
        """ Reacts to a change in the enumeration values (on the edited
            object or on the factory) by refreshing the cache and then
            rebuilding the editor.
        """
        self.values_changed()
        self.rebuild_editor()
Ejemplo n.º 21
0
class FERefinementLevel(FESubDomain):
    '''A subdomain representing one level of element refinement.

    Levels form a tree: a parentless level is attached directly to a domain,
    while a child level takes its domain from its parent and records, in
    ``refinement_dict``, which parent elements it refines.
    '''

    # specialized label
    _tree_label = Str('refinement level')

    def _set_domain(self, value):
        '''Reset the domain of this subdomain.

        Only a parentless level may be assigned a domain directly; a child
        level obtains its domain from its parent (see ``_get_domain``).

        Raises
        ------
        TraitError
            If this level has a parent.
        '''
        if self.parent is not None:
            raise TraitError('child FESubDomain cannot be added to FEDomain')
        super(FERefinementLevel, self)._set_domain(value)

    def _get_domain(self):
        '''Return the parent's domain for a child level, otherwise our own.'''
        if self.parent is not None:
            return self.parent.domain
        return super(FERefinementLevel, self)._get_domain()

    def validate(self):
        '''Check that this level may be inserted into a domain.

        Raises
        ------
        ValueError
            If this level has a parent; only parentless subdomains can be
            inserted into a domain.
        '''
        if self.parent is not None:
            raise ValueError(
                'only parentless subdomains can be inserted into domain')

    # children domains: list of the instances of the same class
    children = List(This)

    # parent domain
    _parent = This(domain_changed=True)
    parent = Property(This)

    def _set_parent(self, value):
        '''Reset the parent of this level.

        On the first assignment the level is appended to the subdomain series
        of ``value.domain``.  On re-parenting, the new parent must belong to
        the same domain as the old one, and the level is removed from the old
        parent's children.  In both cases the level is then appended to the
        new parent's children.

        NOTE(review): *value* is dereferenced unconditionally, so assigning
        None here fails with AttributeError — confirm callers never do that.

        Raises
        ------
        NotImplementedError
            If re-parenting would move the level to a different domain.
        '''
        if self._parent:
            # check to see that the changed parent
            # is within the same domain
            if value.domain != self._parent.domain:
                raise NotImplementedError(
                    'Parent change across domains not implemented')
            # unregister in the old parent
            self._parent.children.remove(self)
        else:
            # append the current subdomain at the end of the subdomain
            # series within the domain
            value.domain._append_in_series(self)
        # set the new parent
        self._parent = value
        # register the subdomain in the new parent
        self._parent.children.append(self)

    def _get_parent(self):
        return self._parent

    #---------------------------------------------------------------------------------------
    # Implement the child interface
    #---------------------------------------------------------------------------------------
    # Element refinement representation: maps a parent element index to the
    # refinement arguments passed to refine_elem.
    #
    refinement_dict = Dict(changed_structure=True)

    def refine_elem(self, parent_ix, *refinement_args):
        '''For the specified parent position let the new element decompose.

        Stores *refinement_args* under *parent_ix* in ``refinement_dict`` and
        deactivates that element in the parent level.

        Raises
        ------
        ValueError
            If the element at *parent_ix* has already been refined.
        '''
        # 'in' instead of the Python-2-only dict.has_key; '%r' reproduces
        # the former backtick (repr) formatting of the message.
        if parent_ix in self.refinement_dict:
            raise ValueError('element %r already refined' % (parent_ix,))

        # the element is deactivated in the parent domain
        self.refinement_dict[parent_ix] = refinement_args
        self.parent.deactivate(parent_ix)
class DataRange1D(BaseDataRange):
    """ Represents a 1-D data range.

    Each bound has a *setting* (what the user asked for: 'auto', 'track', or
    an explicit number) and a *value* (the number actually in effect).
    'auto' bounds are derived from the data of ``sources``; a 'track' bound
    follows the other bound at a distance of ``tracking_amount``.  Whenever
    the effective bounds change, the ``(low, high)`` tuple is assigned to
    ``updated``.
    """

    # The actual value of the lower bound of this range (overrides
    # AbstractDataRange). To set it, use **low_setting**.
    low = Property
    # The actual value of the upper bound of this range (overrides
    # AbstractDataRange). To set it, use **high_setting**.
    high = Property

    # Property for the lower bound of this range (overrides AbstractDataRange).
    #
    # * 'auto': The lower bound is automatically set at or below the minimum
    #   of the data.
    # * 'track': The lower bound tracks the upper bound by **tracking_amount**.
    # * CFloat: An explicit value for the lower bound
    low_setting = Property(Trait('auto', 'auto', 'track', CFloat))
    # Property for the upper bound of this range (overrides AbstractDataRange).
    #
    # * 'auto': The upper bound is automatically set at or above the maximum
    #   of the data.
    # * 'track': The upper bound tracks the lower bound by **tracking_amount**.
    # * CFloat: An explicit value for the upper bound
    high_setting = Property(Trait('auto', 'auto', 'track', CFloat))

    # Do "auto" bounds imply an exact fit to the data? If False,
    # they pad a little bit of margin on either side.
    tight_bounds = Bool(True)

    # A user supplied function returning the proper bounding interval.
    # bounds_func takes (data_low, data_high, margin, tight_bounds)
    # and returns (low, high)
    bounds_func = Callable

    # The amount of margin to place on either side of the data, expressed as
    # a percentage of the full data width
    margin = Float(0.05)

    # The minimum percentage difference between low and high.  That is,
    # (high-low) >= epsilon * low.
    # Used to be 1.0e-20 but chaco cannot plot at such a precision!
    epsilon = CFloat(1.0e-10)

    # When either **high** or **low** tracks the other, track by this amount.
    default_tracking_amount = CFloat(20.0)

    # The current tracking amount. This value changes with zooming.
    tracking_amount = default_tracking_amount

    # Default tracking state. This value is used when self.reset() is called.
    #
    # * 'auto': Both bounds reset to 'auto'.
    # * 'high_track': The high bound resets to 'track', and the low bound
    #   resets to 'auto'.
    # * 'low_track': The low bound resets to 'track', and the high bound
    #   resets to 'auto'.
    default_state = Enum('auto', 'high_track', 'low_track')

    # FIXME: this attribute is not used anywhere, is it safe to remove it?
    # Is this range dependent upon another range?
    fit_to_subset = Bool(False)

    #------------------------------------------------------------------------
    # Private traits
    #------------------------------------------------------------------------

    # The "_setting" attributes correspond to what the user has "set"; the
    # "_value" attributes are the actual numerical values for the given
    # setting.

    # The user-specified low setting.
    _low_setting = Trait('auto', 'auto', 'track', CFloat)
    # The actual numerical value for the low setting.
    _low_value = CFloat(-inf)
    # The user-specified high setting.
    _high_setting = Trait('auto', 'auto', 'track', CFloat)
    # The actual numerical value for the high setting.
    _high_value = CFloat(inf)

    # A list of attributes to persist
    # _pickle_attribs = ("_low_setting", "_high_setting")

    #------------------------------------------------------------------------
    # AbstractRange interface
    #------------------------------------------------------------------------

    def clip_data(self, data):
        """ Returns a list of data values that are within the range.

        Implements AbstractDataRange.
        """
        return compress(self.mask_data(data), data)

    def mask_data(self, data):
        """ Returns a mask array, indicating whether values in the given array
        are inside the range (bounds inclusive).

        Implements AbstractDataRange.
        """
        return ((data.view(ndarray) >= self._low_value) &
                (data.view(ndarray) <= self._high_value))

    def bound_data(self, data):
        """ Returns a tuple of indices for the start and end of the first run
        of *data* that falls within the range, or (0, 0) if there is none.

        Implements AbstractDataRange.
        """
        mask = self.mask_data(data)
        runs = arg_find_runs(mask, "flat")
        # Since runs of "0" are also considered runs, we have to cycle through
        # until we find the first run of "1"s.
        for run in runs:
            if mask[run[0]] == 1:
                # arg_find_runs returns 1 past the end
                return run[0], run[1] - 1
        return (0, 0)

    def set_bounds(self, low, high):
        """ Sets all the bounds of the range simultaneously.

        At most one ``updated`` event is fired for the combined change.

        Implements AbstractDataRange.
        """
        if low == 'track':
            # Set the high setting first, since low tracks it.
            result_high = self._do_set_high_setting(high, fire_event=False)
            result_low = self._do_set_low_setting(low, fire_event=False)
            result = result_low or result_high
        else:
            # Either set low first or order doesn't matter
            result_low = self._do_set_low_setting(low, fire_event=False)
            result_high = self._do_set_high_setting(high, fire_event=False)
            result = result_high or result_low
        if result:
            self.updated = result

    def scale_tracking_amount(self, multiplier):
        """ Sets the **tracking_amount** to a new value, scaled by *multiplier*.
        """
        self.tracking_amount = self.tracking_amount * multiplier
        self._do_track()

    def set_tracking_amount(self, amount):
        """ Sets the **tracking_amount** to a new value, *amount*.
        """
        self.tracking_amount = amount
        self._do_track()

    def set_default_tracking_amount(self, amount):
        """ Sets the **default_tracking_amount** to a new value, *amount*.
        """
        self.default_tracking_amount = amount

    #------------------------------------------------------------------------
    # Public methods
    #------------------------------------------------------------------------

    def reset(self):
        """ Resets the bounds of this range, based on **default_state**.

        **tracking_amount** is restored to **default_tracking_amount** before
        the bounds are recomputed, so the recomputed bounds use the default
        amount rather than a previously zoomed one.
        """
        # need to maintain 'track' setting
        if self.default_state == 'auto':
            self._high_setting = 'auto'
            self._low_setting = 'auto'
        elif self.default_state == 'low_track':
            self._high_setting = 'auto'
            self._low_setting = 'track'
        elif self.default_state == 'high_track':
            self._high_setting = 'track'
            self._low_setting = 'auto'
        # _refresh_bounds() passes tracking_amount to calc_bounds, so it must
        # be restored first; otherwise a 'track' bound would be computed with
        # the stale amount while tracking_amount reports the default.
        self.tracking_amount = self.default_tracking_amount
        self._refresh_bounds()

    def refresh(self):
        """ If any of the bounds is 'auto' or 'track', this method refreshes
        the actual low and high values from the set of the view filters' data
        sources.

        If the user has hard-coded both bounds, refresh() does nothing.
        """
        settings = (self._low_setting, self._high_setting)
        if 'auto' in settings or 'track' in settings:
            self._refresh_bounds()

    #------------------------------------------------------------------------
    # Private methods (getters and setters)
    #------------------------------------------------------------------------

    def _get_low(self):
        return float(self._low_value)

    def _set_low(self, val):
        return self._set_low_setting(val)

    def _get_low_setting(self):
        return self._low_setting

    def _do_set_low_setting(self, val, fire_event=True):
        """
        Returns
        -------
        If fire_event is False and the change would have fired an event, returns
        the tuple of the new low and high values.  Otherwise returns None.  In
        particular, if fire_event is True, it always returns None.
        """
        new_values = None
        if self._low_setting != val:

            # Save the new setting.
            self._low_setting = val

            # If val is 'auto' or 'track', get the corresponding numerical
            # value.
            if val == 'auto':
                if len(self.sources) > 0:
                    val = min(
                        [source.get_bounds()[0] for source in self.sources])
                else:
                    val = -inf
            elif val == 'track':
                if len(self.sources) > 0 or self._high_setting != 'auto':
                    val = self._high_value - self.tracking_amount
                else:
                    val = -inf

            # val is now a numerical value.  If it is the same as the current
            # value, there is nothing to do.
            if self._low_value != val:
                self._low_value = val
                if self._high_setting == 'track':
                    self._high_value = val + self.tracking_amount
                if fire_event:
                    self.updated = (self._low_value, self._high_value)
                else:
                    new_values = (self._low_value, self._high_value)

        return new_values

    def _set_low_setting(self, val):
        self._do_set_low_setting(val, True)

    def _get_high(self):
        return float(self._high_value)

    def _set_high(self, val):
        return self._set_high_setting(val)

    def _get_high_setting(self):
        return self._high_setting

    def _do_set_high_setting(self, val, fire_event=True):
        """
        Returns
        -------
        If fire_event is False and the change would have fired an event, returns
        the tuple of the new low and high values.  Otherwise returns None.  In
        particular, if fire_event is True, it always returns None.
        """
        new_values = None
        if self._high_setting != val:

            # Save the new setting.
            self._high_setting = val

            # If val is 'auto' or 'track', get the corresponding numerical
            # value.
            if val == 'auto':
                if len(self.sources) > 0:
                    val = max(
                        [source.get_bounds()[1] for source in self.sources])
                else:
                    val = inf
            elif val == 'track':
                if len(self.sources) > 0 or self._low_setting != 'auto':
                    val = self._low_value + self.tracking_amount
                else:
                    val = inf

            # val is now a numerical value.  If it is the same as the current
            # value, there is nothing to do.
            if self._high_value != val:
                self._high_value = val
                if self._low_setting == 'track':
                    self._low_value = val - self.tracking_amount
                if fire_event:
                    self.updated = (self._low_value, self._high_value)
                else:
                    new_values = (self._low_value, self._high_value)

        return new_values

    def _set_high_setting(self, val):
        self._do_set_high_setting(val, True)

    def _refresh_bounds(self):
        """ Recomputes both bound values from the current data sources.

        Sources whose ``get_size()`` is 0 are ignored.  When no source has
        data, 'auto'/'track' bounds fall back to -inf/+inf and explicit
        settings are used as-is (no ``updated`` event is fired in that case).
        Otherwise the bounds come from ``calc_bounds`` and ``updated`` is
        fired if either value changed.
        """
        bounds_list = [source.get_bounds() for source in self.sources
                       if source.get_size() > 0]

        if len(bounds_list) == 0:
            # If we have no sources and our settings are "auto", then reset our
            # bounds to infinity; otherwise, set the _value to the corresponding
            # setting.
            if self._low_setting in ("auto", "track"):
                self._low_value = -inf
            else:
                self._low_value = self._low_setting
            if self._high_setting in ("auto", "track"):
                self._high_value = inf
            else:
                self._high_value = self._high_setting
            return

        mins, maxes = list(zip(*bounds_list))

        low_start, high_start = calc_bounds(
            self._low_setting, self._high_setting,
            mins, maxes, self.epsilon,
            self.tight_bounds, margin=self.margin,
            track_amount=self.tracking_amount,
            bounds_func=self.bounds_func)

        if (self._low_value != low_start) or (self._high_value != high_start):
            self._low_value = low_start
            self._high_value = high_start
            self.updated = (self._low_value, self._high_value)

    def _do_track(self):
        """ Re-applies **tracking_amount** to whichever bound is 'track',
        firing ``updated`` if that bound's value changes.
        """
        changed = False
        if self._low_setting == 'track':
            new_value = self._high_value - self.tracking_amount
            if self._low_value != new_value:
                self._low_value = new_value
                changed = True
        elif self._high_setting == 'track':
            new_value = self._low_value + self.tracking_amount
            if self._high_value != new_value:
                self._high_value = new_value
                changed = True
        if changed:
            self.updated = (self._low_value, self._high_value)

    #------------------------------------------------------------------------
    # Event handlers
    #------------------------------------------------------------------------

    def _sources_items_changed(self, event):
        self.refresh()
        for source in event.removed:
            source.on_trait_change(self.refresh, "data_changed", remove=True)
        for source in event.added:
            source.on_trait_change(self.refresh, "data_changed")

    def _sources_changed(self, old, new):
        """ Re-hooks the ``data_changed`` listeners when ``sources`` is
        replaced.

        *old* may be None (``_post_load`` passes None), in which case there
        are no listeners to remove.
        """
        self.refresh()
        if old is not None:
            for source in old:
                source.on_trait_change(self.refresh, "data_changed",
                                       remove=True)
        for source in new:
            source.on_trait_change(self.refresh, "data_changed")

    #------------------------------------------------------------------------
    # Serialization interface
    #------------------------------------------------------------------------

    def _post_load(self):
        self._sources_changed(None, self.sources)
Ejemplo n.º 23
0
class ToolkitEditorFactory(EditorFactory):
    """ Editor factory for buttons.
    """

    # -------------------------------------------------------------------------
    #  Trait definitions:
    # -------------------------------------------------------------------------

    # Value to set when the button is clicked
    value = Property()

    # Optional label for the button
    label = Str()

    # The name of the external object trait that the button label is synced to
    label_value = Str()

    # The name of the trait on the object that contains the list of possible
    # values.  If this is set, then the value, label, and label_value traits
    # are ignored; instead, they will be set from this list.  When this button
    # is clicked, the value set will be the one selected from the drop-down.
    values_trait = Either(None, Str)

    # (Optional) Image to display on the button
    image = Image

    # Extra padding to add to both the left and the right sides
    width_padding = Range(0, 31, 7)

    # Extra padding to add to both the top and the bottom sides
    height_padding = Range(0, 31, 5)

    # Presentation style
    style = Enum("button", "radio", "toolbar", "checkbox")

    # Orientation of the text relative to the image
    orientation = Enum("vertical", "horizontal")

    # The optional view to display when the button is clicked:
    view = AView

    # -------------------------------------------------------------------------
    #  Traits view definition:
    # -------------------------------------------------------------------------

    traits_view = View(["label", "value", "|[]"])

    def _get_value(self):
        """ Returns the value to set when the button is clicked. """
        return self._value

    def _set_value(self, value):
        """ Stores the value to set when the button is clicked.

        A string is converted to ``int`` if it parses as one, otherwise to
        ``float`` if it parses as one; any other string is kept unchanged.
        Non-string values are stored as given.
        """
        if isinstance(value, str):
            for convert in (int, float):
                # Only ValueError signals "not a number"; a bare except here
                # would also swallow KeyboardInterrupt/SystemExit.
                try:
                    value = convert(value)
                except ValueError:
                    continue
                break
        self._value = value

    def __init__(self, **traits):
        # Default returned by the 'value' property until a value is set.
        self._value = 0
        super(ToolkitEditorFactory, self).__init__(**traits)
Ejemplo n.º 24
0
class ScatterPlot(BaseXYPlot):
    """
    Renders a scatter plot, given an index and value arrays.
    """

    # The CompiledPath to use if **marker** is set to "custom". This attribute
    # must be a compiled path for the Kiva context onto which this plot will
    # be rendered.  Usually, importing kiva.GraphicsContext will do
    # the right thing.
    custom_symbol = Any

    #------------------------------------------------------------------------
    # Styles on a ScatterPlot
    #------------------------------------------------------------------------

    # The type of marker to use.  This is a mapped trait using strings as the
    # keys.
    marker = MarkerTrait

    # The pixel size of the markers, not including the thickness of the outline.
    # Default value is 4.0.
    # TODO: for consistency, there should be a size data source and a mapper
    marker_size = Either(Float, Array)

    # The function which actually renders the markers
    render_markers_func = Callable(render_markers)

    # The thickness, in pixels, of the outline to draw around the marker.  If
    # this is 0, no outline is drawn.
    line_width = Float(1.0)

    # The fill color of the marker.
    color = black_color_trait

    # The color of the outline to draw around the marker.
    outline_color = black_color_trait

    # The RGBA tuple for rendering lines.  It is always a tuple of length 4.
    # It has the same RGB values as color_, and its alpha value is the alpha
    # value of self.color multiplied by self.alpha.
    effective_color = Property(Tuple, depends_on=['color', 'alpha'])

    # The RGBA tuple for rendering the fill.  It is always a tuple of length 4.
    # It has the same RGB values as outline_color_, and its alpha value is the
    # alpha value of self.outline_color multiplied by self.alpha.
    effective_outline_color = Property(Tuple,
                                       depends_on=['outline_color', 'alpha'])

    # Traits UI View for customizing the plot.
    traits_view = ScatterPlotView()

    #------------------------------------------------------------------------
    # Selection and selection rendering
    # A selection on the lot is indicated by setting the index or value
    # datasource's 'selections' metadata item to a list of indices, or the
    # 'selection_mask' metadata to a boolean array of the same length as the
    # datasource.
    #------------------------------------------------------------------------

    show_selection = Bool(True)

    selection_marker = MarkerTrait

    selection_marker_size = Float(4.0)

    selection_line_width = Float(1.0)

    selection_color = ColorTrait("yellow")

    selection_outline_color = black_color_trait

    #------------------------------------------------------------------------
    # Private traits
    #------------------------------------------------------------------------

    _cached_selected_pts = Trait(None, None, Array)
    _cached_selected_screen_pts = Array
    _cached_point_mask = Array
    _cached_selection_point_mask = Array
    _selection_cache_valid = Bool(False)

    #------------------------------------------------------------------------
    # Overridden PlotRenderer methods
    #------------------------------------------------------------------------

    def map_screen(self, data_array):
        """ Maps an array of data points into screen space and returns it as
        an array.

        Implements the AbstractPlotRenderer interface.  *data_array* is either
        a single (index, value) pair or an N x 2 sequence of such pairs; an
        empty input yields an empty list.  Each index is mapped through
        ``index_mapper`` and each value through ``value_mapper``; when
        ``orientation`` is not "h" the two screen columns are swapped.
        """
        if len(data_array) == 0:
            return []

        # Columns are extracted by slicing: tuple-unpacking
        # transpose(data_array) was observed to fail intermittently with
        # "object of too small depth for desired array".
        points = asarray(data_array)
        if len(points.shape) == 1:
            index_vals, value_vals = points[0], points[1]
        else:
            index_vals, value_vals = points[:, 0], points[:, 1]

        screen_index = self.index_mapper.map_screen(index_vals)
        screen_value = self.value_mapper.map_screen(value_vals)
        if self.orientation == "h":
            pair = (screen_index, screen_value)
        else:
            pair = (screen_value, screen_index)
        return transpose(array(pair))

    def map_data(self, screen_pt, all_values=True):
        """ Maps a screen space point into the data space of the plot.

        Overrides the BaseXYPlot implementation.  Always returns a two-element
        array ``(index, value)``, whatever the orientation.  *all_values* is
        accepted for interface compatibility and is not used.
        """
        sx, sy = screen_pt
        if self.orientation == 'v':
            index_coord, value_coord = sy, sx
        else:
            index_coord, value_coord = sx, sy
        return array((self.index_mapper.map_data(index_coord),
                      self.value_mapper.map_data(value_coord)))

    def map_index(self, screen_pt, threshold=0.0, outside_returns_none=True, \
                  index_only = False):
        """ Maps a screen space point to an index into the plot's index array(s).

        Overrides the BaseXYPlot implementation..
        """
        if index_only and self.index.sort_order != "none":
            data_pt = self.map_data(screen_pt)[0]
            # The rest of this was copied out of BaseXYPlot.
            # We can't just used BaseXYPlot.map_index because
            # it expect map_data to return a value, not a pair.
            if ((data_pt < self.index_mapper.range.low) or \
                (data_pt > self.index_mapper.range.high)) and outside_returns_none:
                return None
            index_data = self.index.get_data()
            value_data = self.value.get_data()

            if len(value_data) == 0 or len(index_data) == 0:
                return None

            try:
                ndx = reverse_map_1d(index_data, data_pt,
                                     self.index.sort_order)
            except IndexError, e:
                # if reverse_map raises this exception, it means that data_pt is
                # outside the range of values in index_data.
                if outside_returns_none:
                    return None
                else:
                    if data_pt < index_data[0]:
                        return 0
                    else:
                        return len(index_data) - 1

            if threshold == 0.0:
                # Don't do any threshold testing
                return ndx

            x = index_data[ndx]
            y = value_data[ndx]
            if isnan(x) or isnan(y):
                return None
            sx, sy = self.map_screen([x, y])
            if ((threshold == 0.0) or (screen_pt[0] - sx) < threshold):
                return ndx
            else:
                return None
        else:
Ejemplo n.º 25
0
class TableFilterEditor(HasTraits):
    """ An editor that manages table filters.

    Works on clones of the non-template filters of the associated
    ``TableEditor``'s factory.  The user can create new filters from a
    template, delete filters, and edit the selected filter.
    """

    #-------------------------------------------------------------------------
    #  Trait definitions:
    #-------------------------------------------------------------------------

    # TableEditor this editor is associated with
    editor = Instance(TableEditor)

    # The list of filters
    filters = List(TableFilter)

    # The list of available templates from which filters can be created
    templates = Property(List(TableFilter), depends_on='filters')

    # The currently selected filter template
    selected_template = Instance(TableFilter)

    # The currently selected filter
    selected_filter = Instance(TableFilter, allow_none=True)

    # The view to use for the current filter
    selected_filter_view = Property(depends_on='selected_filter')

    # Buttons for add/removing filters
    add_button = Button('New')
    remove_button = Button('Delete')

    # The default view for this editor
    view = View(Group(Group(Group(Item('add_button',
                                       enabled_when='selected_template'),
                                  Item('remove_button',
                                       enabled_when='len(templates) > 1 and '
                                       'selected_filter is not None'),
                                  orientation='horizontal',
                                  show_labels=False),
                            Label('Base filter for new filters:'),
                            Item('selected_template',
                                 editor=EnumEditor(name='templates')),
                            Item('selected_filter',
                                 style='custom',
                                 editor=EnumEditor(name='filters',
                                                   mode='list')),
                            show_labels=False),
                      Item('selected_filter',
                           width=0.75,
                           style='custom',
                           editor=InstanceEditor(view_name='selected_filter_view')),
                      id='TableFilterEditorSplit',
                      show_labels=False,
                      layout='split',
                      orientation='horizontal'),
                id='traitsui.qt4.table_editor.TableFilterEditor',
                buttons=['OK', 'Cancel'],
                kind='livemodal',
                resizable=True, width=800, height=400,
                title='Customize filters')

    #-------------------------------------------------------------------------
    #  Private methods:
    #-------------------------------------------------------------------------

    #-- Trait Property getter/setters ----------------------------------------

    @cached_property
    def _get_selected_filter_view(self):
        """ Returns the edit view for the selected filter, or None if no
        filter is selected.

        The item behind row 0 of the editor's model (mapped through
        ``mapToSource``), if valid, is passed to ``edit_view`` as a sample
        object; otherwise None is passed.
        """
        view = None
        if self.selected_filter:
            model = self.editor.model
            index = model.mapToSource(model.index(0, 0))
            if index.isValid():
                obj = self.editor.items()[index.row()]
            else:
                obj = None
            view = self.selected_filter.edit_view(obj)
        return view

    @cached_property
    def _get_templates(self):
        """ Returns the factory's template filters followed by the local
        filters.
        """
        templates = [f for f in self.editor.factory.filters if f.template]
        templates.extend(self.filters)
        return templates

    #-- Trait Change Handlers ------------------------------------------------

    def _editor_changed(self):
        """ Clones the factory's non-template filters and selects the first
        available template (None when there are no templates at all).
        """
        self.filters = [f.clone_traits() for f in self.editor.factory.filters
                        if not f.template]
        templates = self.templates
        self.selected_template = templates[0] if templates else None

    def _add_button_fired(self):
        """ Create a new filter based on the selected template and select it.
        """
        if self.selected_template is None:
            # The button is disabled in this state; guard programmatic use.
            return
        new_filter = self.selected_template.clone_traits()
        new_filter.template = False
        new_filter.name = new_filter._name = 'New filter'
        self.filters.append(new_filter)
        self.selected_filter = new_filter

    def _remove_button_fired(self):
        """ Delete the currently selected filter.

        The filter that followed it (if any) becomes the new selection.  If
        the deleted filter was also the selected template, a different
        template is selected so the template selection never refers to a
        deleted filter.
        """
        removed = self.selected_filter
        if removed is None:
            # The button is disabled in this state; guard programmatic use.
            return

        if self.selected_template == removed:
            # templates[0] can be the filter being removed (e.g. when the
            # factory defines no template filters), so skip it explicitly.
            self.selected_template = next(
                (t for t in self.templates if t is not removed), None)

        index = self.filters.index(removed)
        del self.filters[index]
        if index < len(self.filters):
            self.selected_filter = self.filters[index]
        else:
            self.selected_filter = None

    @on_trait_change('selected_filter:name')
    def _update_filter_list(self):
        """ A hack to make the EnumEditor watching the list of filters refresh
            their text when the name of the selected filter changes.
        """
        filters = self.filters
        self.filters = []
        self.filters = filters
Ejemplo n.º 26
0
class InterpolationEditor(GraphEditor):
    """Graph editor that plots unknowns together with fitted references.

    For each fit yielded by ``_graph_generator`` a plot is built. It shows
    the unknowns' current values, the reference values with their fit, and
    the values predicted for the unknowns from that fit. X positions are
    derived from timestamps divided by ``_normalization_factor`` (3600) and
    labelled 'Time (hrs)'; presumably timestamps are in seconds (confirm).
    Analyses can optionally be split into bins (``bin_analyses``), each with
    its own subset of references. Subclasses provide the data through the
    ``_get_current_values``, ``_get_reference_values`` and
    ``_set_interpolated_values`` hooks.
    """
    tool_klass = InterpolationFitSelector
    references = List

    auto_find = Bool(True)
    show_current = Bool(True)

    default_reference_analysis_type = 'air'
    sorted_analyses = Property(depends_on='analyses[]')
    sorted_references = Property(depends_on='references[]')
    binned_analyses = List
    bounds = List
    calculate_reference_age = Bool(False)

    _normalization_factor = 3600.

    def _get_min_max(self):
        """Return (earliest, latest) timestamp over references and analyses.

        Both ``sorted_references`` and ``sorted_analyses`` must be non-empty.
        """
        mi = min(self.sorted_references[0].timestamp,
                 self.sorted_analyses[0].timestamp)
        ma = max(self.sorted_references[-1].timestamp,
                 self.sorted_analyses[-1].timestamp)
        return mi, ma

    def bin_analyses(self):
        """Group the analyses into bins and attach matching references.

        The module-level ``bin_analyses`` helper splits ``self.analyses``
        into groups. If there is more than one group:
        - each group becomes a ``BinGroup`` holding the references whose
          timestamps lie strictly between that group's lower and upper
          bound;
        - the group divider positions are stored in ``self.bounds``,
          relative to the padded minimum timestamp;
        - the graph is rebuilt.
        Otherwise ``binned_analyses`` and ``bounds`` are simply cleared.
        """
        groups = list(bin_analyses(self.analyses))
        self.binned_analyses = []
        self.bounds = []
        n = len(groups)
        if n > 1:
            mi, ma = self._get_min_max()
            # Pad by one so the extreme timestamps pass the strict
            # ``low < t < high`` test below.
            mi -= 1
            ma += 1

            bounds = get_bounds(groups)
            # Build a real list: ``map`` returns an iterator on Python 3,
            # which the ``bounds`` List trait would not accept.
            self.bounds = [b - mi for b in bounds]

            gs = []
            low = mi
            for i, gi in enumerate(groups):
                try:
                    high = bounds[i]
                except IndexError:
                    # The last group extends to the padded maximum.
                    high = ma

                # Evaluate eagerly. A lazy ``filter`` with a lambda would
                # capture ``low``/``high`` by reference and, on Python 3,
                # see only their final values.
                refs = [ri for ri in self.sorted_references
                        if low < ri.timestamp < high]
                gs.append(
                    BinGroup(unknowns=gi,
                             references=refs,
                             bounds=((low - mi) / self._normalization_factor,
                                     (high - mi) / self._normalization_factor,
                                     i == 0, i == n - 1)))
                low = high

            self.binned_analyses = gs
            self.rebuild_graph()

    def rebuild_graph(self):
        """Rebuild the graph, then draw a divider at each bin bound."""
        super(InterpolationEditor, self).rebuild_graph()
        if self.bounds:
            for bi in self.bounds:
                self.add_group_divider(bi)

    def add_group_divider(self, cen):
        """Draw a vertical rule at ``cen`` (scaled by the normalization
        factor) and redraw the graph."""
        self.graph.add_vertical_rule(cen / self._normalization_factor,
                                     line_width=1.5,
                                     color='lightblue',
                                     line_style='solid')
        self.graph.redraw()

    def find_references(self):
        """Public entry point for ``_find_references``."""
        self._find_references()

    @on_trait_change('references[]')
    def _update_references(self):
        """Run the subclass hook and rebuild the graph when references
        change."""
        self._update_references_hook()
        self.rebuild_graph()

    def _update_references_hook(self):
        """Hook for subclasses; called before the graph is rebuilt."""
        pass

    def _get_start_end(self, rxs, uxs):
        """Return (min, max) over both timestamp lists.

        An empty list contributes +Inf to the minimum and -Inf to the
        maximum.
        """
        mrxs = min(rxs) if rxs else Inf
        muxs = min(uxs) if uxs else Inf

        marxs = max(rxs) if rxs else -Inf
        mauxs = max(uxs) if uxs else -Inf

        start = min(mrxs, muxs)
        end = max(marxs, mauxs)
        return start, end

    def set_auto_find(self, f):
        """Enable or disable automatic reference lookup on analysis
        updates."""
        self.auto_find = f

    def _update_analyses_hook(self):
        """Look up references automatically when ``auto_find`` is set."""
        self.debug('update analyses hook auto_find={}'.format(self.auto_find))
        if self.auto_find:
            self._find_references()

    def set_references(self, refs, is_append=False, **kw):
        """Create reference analyses from ``refs`` and store them.

        :param refs: records passed to ``self.processor.make_analyses``.
        :param is_append: if True, extend the current references instead of
            replacing them.
        :param kw: forwarded to ``make_analyses``.
        """
        ans = self.processor.make_analyses(
            refs,
            # calculate_age=self.calculate_reference_age,
            # unpack=self.unpack_peaktime,
            **kw)

        if is_append:
            pans = self.references
            pans.extend(ans)
            ans = pans

        self.references = ans

    def _find_references(self):
        """Replace the references with those found by the processor."""
        refs = self.processor.find_references()
        self.references = refs

    def set_interpolated_values(self, iso, reg, ans):
        """Predict values and errors for ``ans`` using regressor ``reg``.

        :param iso: isotope key passed through to the subclass hook.
        :param reg: regressor providing ``predict``/``predict_error``.
        :param ans: analyses to predict for; defaults to
            ``sorted_analyses`` when None.
        :returns: (predicted values, predicted errors).
        """
        mi, ma = self._get_min_max()
        if ans is None:
            ans = self.sorted_analyses

        xs = [(ai.timestamp - mi) / self._normalization_factor for ai in ans]

        p_uys = reg.predict(xs)
        p_ues = reg.predict_error(xs)
        self._set_interpolated_values(iso, ans, p_uys, p_ues)
        return p_uys, p_ues

    def _set_interpolated_values(self, *args, **kw):
        """Hook for subclasses: store predicted values on the analyses."""
        pass

    def _get_current_values(self, *args, **kw):
        """Hook for subclasses: return (values, errors) for the unknowns."""
        pass

    def _get_reference_values(self, *args, **kw):
        """Hook for subclasses: return (values, errors) for the
        references."""
        pass

    def _get_isotope(self, ui, k, kind=None):
        """Return (value, error) of isotope ``k`` of analysis ``ui``.

        If ``kind`` is given, the attribute of that name on the isotope is
        used instead of the isotope itself. Returns (0, 0) when ``k`` is
        absent.
        """
        if k in ui.isotopes:
            v = ui.isotopes[k]
            if kind is not None:
                v = getattr(v, kind)
            v = v.value, v.error
        else:
            v = 0, 0
        return v

    def _rebuild_graph(self):
        """Build one plot per fit from ``_graph_generator`` and set the
        x limits."""
        graph = self.graph

        uxs = [ui.timestamp for ui in self.analyses]
        rxs = [ui.timestamp for ui in self.references]

        start, end = self._get_start_end(rxs, uxs)

        # Naming: c_... current value, r_... reference value,
        # p_... predicted value.
        c_uxs = self.normalize(uxs, start)
        r_xs = self.normalize(rxs, start)

        gen = self._graph_generator()
        for i, fit in enumerate(gen):
            iso = fit.name
            fit = fit.fit_tuple()
            if self.binned_analyses:
                self._build_binned(i, iso, fit, start)
            else:
                self._build_non_binned(i, iso, fit, c_uxs, r_xs)

            if i == 0:
                self._add_legend()

        m = abs(end - start) / self._normalization_factor
        graph.set_x_limits(0, m, pad='0.1')
        graph.refresh()

    def _build_binned(self, i, iso, fit, start):
        """Create plot ``i`` for ``iso`` with one series set per bin."""
        graph = self.graph
        p = graph.new_plot(ytitle=iso,
                           xtitle='Time (hrs)',
                           padding=[80, 10, 5, 30])
        p.y_axis.title_spacing = 60
        p.value_range.tight_bounds = False

        for j, gi in enumerate(self.binned_analyses):
            bis = gi.unknowns
            refs = gi.references

            c_xs = self.normalize([bi.timestamp for bi in bis], start)
            rxs = [bi.timestamp for bi in refs]
            r_xs = self.normalize(rxs, start)
            # A list, not ``map``: asarray(map(...)) on Python 3 yields a
            # 0-d object array instead of a 1-d array.
            dx = asarray([convert_timestamp(rx) for rx in rxs])

            c_ys, c_es = None, None
            if self.show_current:
                c_ys, c_es = self._get_current_values(iso, bis)

            r_ys, r_es = None, None
            if refs:
                r_ys, r_es = self._get_reference_values(iso, refs)

            current_args = bis, c_xs, c_ys, c_es
            ref_args = refs, r_xs, r_ys, r_es, dx
            self._build_plot(i,
                             iso,
                             fit,
                             current_args,
                             ref_args,
                             series_id=j,
                             regression_bounds=gi.bounds)

    def _build_non_binned(self, i, iso, fit, c_xs, r_xs):
        """Create plot ``i`` for ``iso`` using all analyses and
        references."""
        c_ys, c_es = None, None
        if self.analyses and self.show_current:
            c_ys, c_es = self._get_current_values(iso)

        r_ys, r_es, dx = None, None, None
        if self.references:
            r_ys, r_es = self._get_reference_values(iso)
            # List comprehension so asarray gets a sequence on Python 3.
            dx = asarray([convert_timestamp(ui.timestamp)
                          for ui in self.references])

        current_args = self.sorted_analyses, c_xs, c_ys, c_es
        ref_args = self.sorted_references, r_xs, r_ys, r_es, dx

        graph = self.graph
        p = graph.new_plot(
            ytitle=iso,
            xtitle='Time (hrs)',
            # padding=[80, 10, 5, 30]
            padding=[80, 80, 80, 80])
        p.y_axis.title_spacing = 60
        p.value_range.tight_bounds = False

        self._build_plot(i, iso, fit, current_args, ref_args)

    def _plot_unknowns_current(self, ans, c_es, c_xs, c_ys, i):
        """Add the unknowns' current values as a black square scatter."""
        graph = self.graph
        if c_es and c_ys:
            s, _p = graph.new_series(c_xs,
                                     c_ys,
                                     yerror=c_es,
                                     fit=False,
                                     type='scatter',
                                     plotid=i,
                                     marker='square',
                                     marker_size=3,
                                     bind_id=-1,
                                     color='black',
                                     add_inspector=False)
            self._add_inspector(s, ans)
            self._add_error_bars(s, c_es)

            graph.set_series_label('Unknowns-Current', plotid=i)

    def _plot_interpolated(self, ans, c_xs, i, iso, reg, series_id):
        """Add the values predicted by ``reg`` for ``ans`` as a blue
        scatter."""
        p_uys, p_ues = self.set_interpolated_values(iso, reg, ans)
        if len(p_uys):
            graph = self.graph
            s, p = graph.new_series(c_xs,
                                    p_uys,
                                    isotope=iso,
                                    yerror=ArrayDataSource(p_ues),
                                    fit=False,
                                    add_tools=False,
                                    add_inspector=False,
                                    type='scatter',
                                    marker_size=3,
                                    color='blue',
                                    plotid=i,
                                    bind_id=-1)
            series = len(p.plots) - 1
            # The label is the key later looked up by ``_set_values``.
            graph.set_series_label('Unknowns-predicted{}'.format(series_id),
                                   plotid=i,
                                   series=series)

            self._add_error_bars(s, p_ues)

    def _plot_references(self, ref_args, fit, i, regression_bounds, series_id):
        """Plot the references and return the regressor fitted to them.

        Interpolation fits ('preceding', 'bracketing interpolate',
        'bracketing average') use an ``InterpolationRegressor``. Any other
        fit uses the regressor attached to the fitted line, if it has one.

        :returns: the regressor, or None when there are no reference values
            or the fitted line exposes no regressor.
        """
        refs, r_xs, r_ys, r_es, display_xs = ref_args
        graph = self.graph
        # Initialise so every path returns a defined value; previously a
        # line without ``regressor`` raised UnboundLocalError.
        reg = None
        if not r_ys:
            return reg

        efit = fit[0]
        if efit in [
                'preceding', 'bracketing interpolate', 'bracketing average'
        ]:
            reg = InterpolationRegressor(xs=r_xs,
                                         ys=r_ys,
                                         yserr=r_es,
                                         kind=efit)
            s, _p = graph.new_series(r_xs,
                                     r_ys,
                                     yerror=r_es,
                                     type='scatter',
                                     plotid=i,
                                     fit=False,
                                     marker_size=3,
                                     color='red',
                                     add_inspector=False)
            self._add_inspector(s, refs)
            self._add_error_bars(s, r_es)
        else:
            _p, s, l = graph.new_series(
                r_xs,
                r_ys,
                display_index=ArrayDataSource(data=display_xs),
                yerror=ArrayDataSource(data=r_es),
                fit=fit,
                color='red',
                plotid=i,
                marker_size=3,
                add_inspector=False)
            if hasattr(l, 'regressor'):
                reg = l.regressor

            l.regression_bounds = regression_bounds

            self._add_inspector(s, refs)
            self._add_error_bars(s, array(r_es))

        return reg

    def _build_plot(self,
                    i,
                    iso,
                    fit,
                    current_args,
                    ref_args,
                    series_id=0,
                    regression_bounds=None):
        """Plot the current unknowns and the references, and plot the
        predictions when the references produced a regressor."""
        ans, c_xs, c_ys, c_es = current_args

        self._plot_unknowns_current(ans, c_es, c_xs, c_ys, i)
        reg = self._plot_references(ref_args, fit, i, regression_bounds,
                                    series_id)
        if reg:
            self._plot_interpolated(ans, c_xs, i, iso, reg, series_id)

    def _add_legend(self):
        """Hook for subclasses; called after the first plot is built."""
        pass

    def _add_error_bars(self,
                        scatter,
                        errors,
                        orientation='y',
                        visible=True,
                        nsigma=1,
                        line_width=1):
        """Attach an ``ErrorBarOverlay`` to ``scatter``.

        ``errors`` is stored on the scatter as ``<orientation>error``.

        :returns: the created overlay.
        """
        from pychron.graph.error_bar_overlay import ErrorBarOverlay

        ebo = ErrorBarOverlay(component=scatter,
                              orientation=orientation,
                              nsigma=nsigma,
                              line_width=line_width,
                              visible=visible)

        scatter.underlays.append(ebo)
        setattr(scatter, '{}error'.format(orientation),
                ArrayDataSource(errors))
        return ebo

    def _add_inspector(self, scatter, ans):
        """Attach the tools and overlays to ``scatter``.

        The scatter gets rectangle selection and point inspection (both
        driven through a broadcaster tool). Its index metadata changes are
        mapped back onto ``ans``.
        """
        broadcaster = BroadcasterTool()
        scatter.tools.append(broadcaster)

        rect_tool = RectSelectionTool(scatter)
        rect_overlay = RectSelectionOverlay(component=scatter, tool=rect_tool)

        scatter.overlays.append(rect_overlay)
        broadcaster.tools.append(rect_tool)

        point_inspector = AnalysisPointInspector(
            scatter,
            value_format=floatfmt,
            analyses=ans,
            convert_index=lambda x: '{:0.3f}'.format(x))

        pinspector_overlay = PointInspectorOverlay(component=scatter,
                                                   tool=point_inspector)

        scatter.overlays.append(pinspector_overlay)
        scatter.tools.append(point_inspector)
        broadcaster.tools.append(point_inspector)

        scatter.index.on_trait_change(self._update_metadata(ans),
                                      'metadata_changed')

    def _update_metadata(self, ans):
        """Return a ``metadata_changed`` handler bound to ``ans``.

        The handler sets ``temp_status`` on each analysis to whether its
        position is in the index's selection list.
        """
        def _handler(obj, name, old, new):
            meta = obj.metadata
            selections = meta.get('selections')
            if selections is None:
                # Metadata changed without a selection list (e.g. only
                # other keys updated); leave statuses untouched instead of
                # raising KeyError.
                return
            for i, ai in enumerate(ans):
                ai.temp_status = i in selections

        return _handler

    @cached_property
    def _get_sorted_analyses(self):
        """Return the analyses ordered by timestamp."""
        return sorted(self.analyses, key=lambda x: x.timestamp)

    @cached_property
    def _get_sorted_references(self):
        """Return the references ordered by timestamp."""
        return sorted(self.references, key=lambda x: x.timestamp)

    @on_trait_change('graph:regression_results')
    def _update_regression(self, new):
        """Refresh predicted series after a regression changes.

        Regressions are re-run, for example, when the user excludes points.
        ``new`` is a flat sequence of (plot, regressor) pairs. In binned
        mode these are consumed one per (fit, bin) in order.
        """
        key = 'Unknowns-predicted{}'
        if self.binned_analyses:
            gen = self._graph_generator()

            c = 0
            for j, fit in enumerate(gen):
                for i, g in enumerate(self.binned_analyses):
                    try:
                        plotobj, reg = new[c]
                    except IndexError:
                        break

                    if issubclass(type(reg), BaseRegressor):
                        k = key.format(i)
                        self._set_values(fit, plotobj, reg, k, g.unknowns)
                    c += 1
        else:
            key = key.format(0)
            gen = self._graph_generator()
            for fit, (plotobj, reg) in zip(gen, new):
                if issubclass(type(reg), BaseRegressor):
                    self._set_values(fit, plotobj, reg, key)

    def _set_values(self, fit, plotobj, reg, key, ans=None):
        """Recompute predictions with ``reg`` and update the series labelled
        ``key`` on ``plotobj``, if that series exists."""
        iso = fit.name
        if key in plotobj.plots:
            scatter = plotobj.plots[key][0]
            p_uys, p_ues = self.set_interpolated_values(iso, reg, ans)
            scatter.value.set_data(p_uys)
            scatter.yerror.set_data(p_ues)

    def _clean_references(self):
        """Return the references whose ``temp_status`` equals 0."""
        return [ri for ri in self.references if ri.temp_status == 0]
Ejemplo n.º 27
0
class Volume(Module):
    """The Volume module visualizes scalar fields using volumetric
    visualization techniques.  This supports ImageData and
    UnstructuredGrid data, and also multi-block and AMR datasets when the
    installed VTK provides the corresponding mappers.  It also supports the
    FixedPointRenderer for ImageData.  However, the performance is slow so
    your best bet is probably with the ImageData based renderers.
    """

    # The version of this class.  Used for persistence.
    __version__ = 0

    volume_mapper_type = DEnum(values_name='_mapper_types',
                               desc='volume mapper to use')

    ray_cast_function_type = DEnum(values_name='_ray_cast_functions',
                                   desc='Ray cast function to use')

    volume = ReadOnly

    volume_mapper = Property(record=True)

    volume_property = Property(record=True)

    ray_cast_function = Property(record=True)

    lut_manager = Instance(VolumeLUTManager,
                           args=(),
                           allow_none=False,
                           record=True)

    input_info = PipelineInfo(datasets=['image_data', 'unstructured_grid'],
                              attribute_types=['any'],
                              attributes=['scalars'])

    ########################################
    # View related code.

    update_ctf = Button('Update CTF')

    view = View(Group(Item(name='_volume_property',
                           style='custom',
                           editor=VolumePropertyEditor,
                           resizable=True),
                      Item(name='update_ctf'),
                      label='CTF',
                      show_labels=False),
                Group(
                    Item(name='volume_mapper_type'),
                    Group(Item(name='_volume_mapper',
                               style='custom',
                               resizable=True),
                          show_labels=False),
                    Item(name='ray_cast_function_type'),
                    Group(Item(name='_ray_cast_function',
                               enabled_when='len(_ray_cast_functions) > 0',
                               style='custom',
                               resizable=True),
                          show_labels=False),
                    label='Mapper',
                ),
                Group(Item(name='_volume_property',
                           style='custom',
                           resizable=True),
                      label='Property',
                      show_labels=False),
                Group(Item(name='volume',
                           style='custom',
                           editor=InstanceEditor(),
                           resizable=True),
                      label='Volume',
                      show_labels=False),
                Group(Item(name='lut_manager', style='custom', resizable=True),
                      label='Legend',
                      show_labels=False),
                resizable=True)

    ########################################
    # Private traits
    _volume_mapper = Instance(tvtk.AbstractVolumeMapper)
    _volume_property = Instance(tvtk.VolumeProperty)
    _ray_cast_function = Instance(tvtk.Object)

    _mapper_types = List(Str, [
        'TextureMapper2D',
        'RayCastMapper',
    ])

    _available_mapper_types = List(Str)

    _ray_cast_functions = List(Str)

    current_range = Tuple

    # The color transfer function.
    _ctf = Instance(ColorTransferFunction)
    # The opacity values.
    _otf = Instance(PiecewiseFunction)

    # A cache for the mappers, a dict keyed by class.
    _mapper_cache = Dict

    ######################################################################
    # `object` interface
    ######################################################################
    def __get_pure_state__(self):
        """Return the persisted state.

        The transfer functions are saved as 'ctf_state'. The transient
        'current_range', '_ctf' and '_otf' entries are dropped.
        """
        d = super(Volume, self).__get_pure_state__()
        d['ctf_state'] = save_ctfs(self._volume_property)
        for name in ('current_range', '_ctf', '_otf'):
            d.pop(name, None)
        return d

    def __set_pure_state__(self, state):
        """Restore state produced by ``__get_pure_state__``.

        Restores the mapper type, the remaining traits and the transfer
        functions, then pushes the LUT into the volume property.
        """
        self.volume_mapper_type = state['_volume_mapper_type']
        state_pickler.set_state(self, state, ignore=['ctf_state'])
        ctf_state = state['ctf_state']
        ctf, otf = load_ctfs(ctf_state, self._volume_property)
        self._ctf = ctf
        self._otf = otf
        self._update_ctf_fired()

    ######################################################################
    # `Module` interface
    ######################################################################
    def start(self):
        """Start the module and its LUT manager."""
        super(Volume, self).start()
        self.lut_manager.start()

    def stop(self):
        """Stop the module and its LUT manager."""
        super(Volume, self).stop()
        self.lut_manager.stop()

    def setup_pipeline(self):
        """Override this method so that it *creates* the tvtk
        pipeline.

        Creates the volume and its property with default color/opacity
        transfer functions over 0-255, and records which volume mappers
        this VTK build provides.
        """
        v = self.volume = tvtk.Volume()
        vp = self._volume_property = tvtk.VolumeProperty()

        self._ctf = ctf = default_CTF(0, 255)
        self._otf = otf = default_OTF(0, 255)
        vp.set_color(ctf)
        vp.set_scalar_opacity(otf)
        vp.shade = True
        vp.interpolation_type = 'linear'
        v.property = vp

        v.on_trait_change(self.render)
        vp.on_trait_change(self.render)

        available_mappers = find_volume_mappers()
        if is_volume_pro_available():
            self._mapper_types.append('VolumeProMapper')
            available_mappers.append('VolumeProMapper')

        self._available_mapper_types = available_mappers
        if 'FixedPointVolumeRayCastMapper' in available_mappers:
            self._mapper_types.append('FixedPointVolumeRayCastMapper')

        self.actors.append(v)

    def update_pipeline(self):
        """Override this method so that it *updates* the tvtk pipeline
        when data upstream is known to have changed.

        This method is invoked (automatically) when any of the inputs
        sends a `pipeline_changed` event. Unsupported dataset types are
        reported via ``error`` and leave the pipeline unchanged.
        """
        mm = self.module_manager
        if mm is None:
            return

        dataset = mm.source.get_output_dataset()

        ug = hasattr(tvtk, 'UnstructuredGridVolumeMapper')
        if dataset.is_a('vtkMultiBlockDataSet'):
            if 'MultiBlockVolumeMapper' not in self._available_mapper_types:
                error('Your version of VTK does not support '
                      'Multiblock volume rendering')
                return
        elif dataset.is_a('vtkUniformGridAMR'):
            if 'AMRVolumeMapper' not in self._available_mapper_types:
                error('Your version of VTK does not support '
                      'AMR volume rendering')
                return
        elif ug:
            if not dataset.is_a('vtkImageData') \
                   and not dataset.is_a('vtkUnstructuredGrid'):
                error('Volume rendering only works with '
                      'StructuredPoints/ImageData/UnstructuredGrid datasets')
                return
        elif not dataset.is_a('vtkImageData'):
            error('Volume rendering only works with '
                  'StructuredPoints/ImageData datasets')
            return

        self._setup_mapper_types()
        self._setup_current_range()
        self._volume_mapper_type_changed(self.volume_mapper_type)
        self._update_ctf_fired()
        self.pipeline_changed = True

    def update_data(self):
        """Override this method so that it flushes the vtk pipeline if
        that is necessary.

        This method is invoked (automatically) when any of the inputs
        sends a `data_changed` event.
        """
        self._setup_mapper_types()
        self._setup_current_range()
        self._update_ctf_fired()
        self.data_changed = True

    ######################################################################
    # Non-public methods.
    ######################################################################
    def _get_image_data_volume_mappers(self):
        """Return the available GPU/smart mappers usable for image data."""
        check = ('SmartVolumeMapper', 'GPUVolumeRayCastMapper',
                 'OpenGLGPUVolumeRayCastMapper')
        return [x for x in check if x in self._available_mapper_types]

    def _setup_mapper_types(self):
        """Sets up the mapper based on input data types.

        Restricts ``_mapper_types`` (the choices offered for
        ``volume_mapper_type``) to the available mappers that can handle
        the current output dataset.
        """
        dataset = self.module_manager.source.get_output_dataset()
        if dataset.is_a('vtkMultiBlockDataSet'):
            if 'MultiBlockVolumeMapper' in self._available_mapper_types:
                self._mapper_types = ['MultiBlockVolumeMapper']
        elif dataset.is_a('vtkUniformGridAMR'):
            # Was ``not in``: the AMR mapper was only offered when it was
            # *unavailable*. Mirror the multi-block branch above.
            if 'AMRVolumeMapper' in self._available_mapper_types:
                self._mapper_types = ['AMRVolumeMapper']
        elif dataset.is_a('vtkUnstructuredGrid'):
            if hasattr(tvtk, 'UnstructuredGridVolumeMapper'):
                check = [
                    'UnstructuredGridVolumeZSweepMapper',
                    'UnstructuredGridVolumeRayCastMapper',
                ]
                mapper_types = [mapper for mapper in check
                                if mapper in self._available_mapper_types]
                if len(mapper_types) == 0:
                    # Placeholder so the DEnum has a value; it maps to no
                    # mapper in ``_volume_mapper_type_changed``.
                    mapper_types = ['']
                self._mapper_types = mapper_types
                return
        else:
            mapper_types = self._get_image_data_volume_mappers()
            if dataset.point_data.scalars.data_type not in \
               [vtkConstants.VTK_UNSIGNED_CHAR,
                vtkConstants.VTK_UNSIGNED_SHORT]:
                if 'FixedPointVolumeRayCastMapper' \
                       in self._available_mapper_types:
                    mapper_types.append('FixedPointVolumeRayCastMapper')
                elif len(mapper_types) == 0:
                    error('Available volume mappers only work with '
                          'unsigned_char or unsigned_short datatypes')
            else:
                check = [
                    'FixedPointVolumeRayCastMapper', 'VolumeProMapper',
                    'TextureMapper2D', 'RayCastMapper', 'TextureMapper3D'
                ]
                for mapper in check:
                    if mapper in self._available_mapper_types:
                        mapper_types.append(mapper)
            self._mapper_types = mapper_types

    def _setup_current_range(self):
        """Copy the LUT defaults from the module manager, then set
        ``current_range`` from the point scalars.

        Falls back to (0, 255) and reports an error when the input has no
        scalars.
        """
        mm = self.module_manager
        # Set the default name and range for our lut.
        lm = self.lut_manager
        slm = mm.scalar_lut_manager
        lm.trait_set(default_data_name=slm.default_data_name,
                     default_data_range=slm.default_data_range)

        # Set the current range.
        dataset = mm.source.get_output_dataset()
        dsh = DataSetHelper(dataset)
        name, rng = dsh.get_range('scalars', 'point')
        if name is None:
            error('No scalars in input data!')
            rng = (0, 255)

        if self.current_range != rng:
            self.current_range = rng

    def _get_volume_mapper(self):
        return self._volume_mapper

    def _get_volume_property(self):
        return self._volume_property

    def _get_ray_cast_function(self):
        return self._ray_cast_function

    def _get_mapper(self, klass):
        """ Return a mapper of the given class. Either from the cache or by
        making a new one.
        """
        result = self._mapper_cache.get(klass)
        if result is None:
            result = klass()
            self._mapper_cache[klass] = result
        return result

    def _volume_mapper_type_changed(self, value):
        """Swap in the mapper named by ``value`` and connect it to the
        source.

        Values with no matching mapper (e.g. the '' placeholder used when
        no unstructured-grid mapper exists) leave the current mapper in
        place. Previously such values raised UnboundLocalError.
        """
        mm = self.module_manager
        if mm is None:
            return

        # mapper type -> (tvtk class name, ray cast function choices).
        # Class names are resolved lazily because not every VTK build
        # provides every mapper class.
        mapper_specs = {
            'RayCastMapper': ('VolumeRayCastMapper', [
                'RayCastCompositeFunction', 'RayCastMIPFunction',
                'RayCastIsosurfaceFunction'
            ]),
            'MultiBlockVolumeMapper': ('MultiBlockVolumeMapper', ['']),
            'AMRVolumeMapper': ('AMRVolumeMapper', ['']),
            'SmartVolumeMapper': ('SmartVolumeMapper', ['']),
            'GPUVolumeRayCastMapper': ('GPUVolumeRayCastMapper', ['']),
            'OpenGLGPUVolumeRayCastMapper':
            ('OpenGLGPUVolumeRayCastMapper', ['']),
            'TextureMapper2D': ('VolumeTextureMapper2D', ['']),
            'TextureMapper3D': ('VolumeTextureMapper3D', ['']),
            'VolumeProMapper': ('VolumeProMapper', ['']),
            'FixedPointVolumeRayCastMapper':
            ('FixedPointVolumeRayCastMapper', ['']),
            'UnstructuredGridVolumeRayCastMapper':
            ('UnstructuredGridVolumeRayCastMapper', ['']),
            'UnstructuredGridVolumeZSweepMapper':
            ('UnstructuredGridVolumeZSweepMapper', ['']),
        }
        spec = mapper_specs.get(value)
        if spec is None:
            return
        klass_name, ray_cast_functions = spec

        old_vm = self._volume_mapper
        if old_vm is not None:
            old_vm.on_trait_change(self.render, remove=True)
            try:
                old_vm.remove_all_input_connections(0)
            except AttributeError:
                pass

        new_vm = self._get_mapper(getattr(tvtk, klass_name))
        self._volume_mapper = new_vm
        self._ray_cast_functions = ray_cast_functions
        if value == 'RayCastMapper':
            new_vm.volume_ray_cast_function = self._get_mapper(
                tvtk.VolumeRayCastCompositeFunction)

        src = mm.source
        self.configure_input(new_vm, src.outputs[0])
        self.volume.mapper = new_vm
        new_vm.on_trait_change(self.render)

    def _update_ctf_fired(self):
        """Push the LUT manager's lut into the volume property and
        re-render."""
        set_lut(self.lut_manager.lut, self._volume_property)
        self.render()

    def _current_range_changed(self, old, new):
        """Rescale the transfer functions to the new data range."""
        rescale_ctfs(self._volume_property, new)
        self.render()

    def _ray_cast_function_type_changed(self, old, new):
        """Replace the ray cast function with the tvtk ``Volume<new>``
        class.

        An empty ``new`` clears the ray cast function.
        """
        rcf = self.ray_cast_function
        # ``_ray_cast_function`` may still be None (the RayCastMapper path
        # sets the mapper's function directly), so guard the unhook.
        if len(old) > 0 and rcf is not None:
            rcf.on_trait_change(self.render, remove=True)

        if len(new) > 0:
            new_rcf = getattr(tvtk, 'Volume%s' % new)()
            new_rcf.on_trait_change(self.render)
            self._volume_mapper.volume_ray_cast_function = new_rcf
            self._ray_cast_function = new_rcf
        else:
            self._ray_cast_function = None

        self.render()

    def _scene_changed(self, old, new):
        """Propagate the scene to the LUT manager."""
        super(Volume, self)._scene_changed(old, new)
        self.lut_manager.scene = new
Ejemplo n.º 28
0
class NamespaceView(View):
    """ A view that tabulates the names bound in the Python shell namespace. """

    #### 'IView' interface ####################################################

    # Globally unique identifier of this part.
    id = "enthought.plugins.python_shell.view.namespace_view"

    # Human-readable name of the view.
    name = "Namespace"

    # Where the view is placed by default, relative to the item named by the
    # 'relative_to' trait.
    position = "left"

    #### 'NamespaceView' interface ############################################

    # One HasTraits row per namespace binding, each carrying 'name', 'type'
    # and 'module' string traits. Recomputed when 'namespace' changes.
    bindings = Property(List, observe="namespace")

    shell_view = Instance(PythonShellView)

    # Delegated to the attached shell view.
    namespace = DelegatesTo("shell_view")

    # Default traits UI layout: a bordered group holding the bindings table.
    traits_view = TraitsView(
        VGroup(
            Item(
                "bindings",
                id="table",
                editor=table_editor,
                springy=True,
                resizable=True,
            ),
            show_border=True,
            show_labels=False,
        ),
        resizable=True,
    )

    ###########################################################################
    # 'View' interface.
    ###########################################################################

    def create_control(self, parent):
        """ Build and return the toolkit-specific control for this view.

        'parent' is the toolkit-specific control that will contain it. The
        returned control is that of a 'subpanel' traits UI for this object.

        """
        self.ui = self.edit_traits(parent=parent, kind="subpanel")

        # The IPythonShell service is the PythonShellView defined in
        # envisage.plugins.python_shell.view.python_shell_view.
        application = self.window.application
        self.shell_view = application.get_service(IPythonShell)

        return self.ui.control

    ###########################################################################
    # 'NamespaceView' interface.
    ###########################################################################

    #### Properties ###########################################################

    @cached_property
    def _get_bindings(self):
        """ Getter for 'bindings': one row object per namespace entry.

        Returns an empty list while no shell view is attached.

        """
        shell = self.shell_view
        if shell is None:
            return []

        class item(HasTraits):
            name = Str
            type = Str
            module = Str

        rows = []
        for binding_name, binding_value in shell.namespace.items():
            rows.append(
                item(
                    name=binding_name,
                    type=type_to_str(binding_value),
                    module=module_to_str(binding_value),
                )
            )
        return rows
Ejemplo n.º 29
0
class MRISubjectSource(HasPrivateTraits):
    """Find subjects in SUBJECTS_DIR and select one.

    Parameters
    ----------
    subjects_dir : directory
        SUBJECTS_DIR.
    subject : str
        Subject, corresponding to a folder in SUBJECTS_DIR.
    """

    refresh = Event(desc="Refresh the subject list based on the directory "
                    "structure of subjects_dir.")

    # settings
    subjects_dir = Directory(exists=True)
    subjects = Property(List(Str), depends_on=['subjects_dir', 'refresh'])
    subject = Enum(values='subjects')
    use_high_res_head = Bool(True)

    # info
    can_create_fsaverage = Property(Bool,
                                    depends_on=['subjects_dir', 'subjects'])
    subject_has_bem = Property(Bool,
                               depends_on=['subjects_dir', 'subject'],
                               desc="whether the subject has a file matching "
                               "the bem file name pattern")
    # Directory of the selected subject, or None if either the subject or
    # subjects_dir is unset. Declared explicitly so that the cached getter
    # ``_get_mri_dir`` is reachable (and invalidated when its inputs change),
    # and so that ``bem_pattern``'s dependency refers to a real trait.
    mri_dir = Property(depends_on=['subjects_dir', 'subject'])
    bem_pattern = Property(depends_on='mri_dir')

    @cached_property
    def _get_can_create_fsaverage(self):
        # fsaverage can only be created inside an existing subjects_dir that
        # does not already list it as a subject.
        return (os.path.exists(self.subjects_dir) and
                'fsaverage' not in self.subjects)

    @cached_property
    def _get_mri_dir(self):
        if not (self.subject and self.subjects_dir):
            return None
        return os.path.join(self.subjects_dir, self.subject)

    @cached_property
    def _get_subjects(self):
        """Return the sorted MRI subjects found in subjects_dir.

        ``os.listdir`` yields entries in arbitrary order, so the list is
        sorted to give a stable choice order (and default selection) for the
        ``subject`` Enum. ``['']`` is returned when subjects_dir is unset,
        is not a directory, or contains no subjects.
        """
        sdir = self.subjects_dir
        if not (sdir and os.path.isdir(sdir)):
            return ['']

        subjects = sorted(entry for entry in os.listdir(sdir)
                          if _is_mri_subject(entry, sdir))
        return subjects or ['']

    @cached_property
    def _get_subject_has_bem(self):
        if not self.subject:
            return False
        return _mri_subject_has_bem(self.subject, self.subjects_dir)

    def create_fsaverage(self):
        """Create the fsaverage subject in subjects_dir and select it.

        The work is delegated to ``create_default_subject``; afterwards the
        subject list is refreshed and ``subject`` is set to 'fsaverage'.

        Raises
        ------
        RuntimeError
            If subjects_dir is unset, or if ``get_mne_root()`` or
            ``get_fs_home()`` returns None.
        """
        if not self.subjects_dir:
            err = ("No subjects directory is selected. Please specify "
                   "subjects_dir first.")
            raise RuntimeError(err)

        mne_root = get_mne_root()
        if mne_root is None:
            err = ("MNE contains files that are needed for copying the "
                   "fsaverage brain. Please install MNE and try again.")
            raise RuntimeError(err)
        fs_home = get_fs_home()
        if fs_home is None:
            err = ("FreeSurfer contains files that are needed for copying the "
                   "fsaverage brain. Please install FreeSurfer and try again.")
            raise RuntimeError(err)

        create_default_subject(mne_root,
                               fs_home,
                               subjects_dir=self.subjects_dir)
        # Fire the event so the cached subject list is recomputed before
        # selecting the newly created subject.
        self.refresh = True
        self.subject = 'fsaverage'
Ejemplo n.º 30
0
class Axis(ConfigLoadable):
    '''
    '''
    id = Int
    #    name = Str
    position = Float
    negative_limit = Float
    positive_limit = Float
    pdir = Str
    parent = Any(transient=True)
    calculate_parameters = Bool(True)
    drive_ratio = Float(1)

    velocity = Property(depends_on='_velocity')
    _velocity = Float(enter_set=True, auto_set=False)

    acceleration = Property(depends_on='_acceleration')
    _acceleration = Float(enter_set=True, auto_set=False)

    deceleration = Property(depends_on='_deceleration')
    _deceleration = Float(enter_set=True, auto_set=False)

    machine_velocity = Float
    machine_acceleration = Float
    machine_deceleration = Float
    # sets handled by child class

    nominal_velocity = Float
    nominal_acceleration = Float
    nominal_deceleration = Float

    sign = CInt(1)

    def _get_velocity(self):
        '''Getter for the ``velocity`` property; reads ``_velocity``.'''
        value = self._velocity
        return value

    def _get_acceleration(self):
        '''Getter for the ``acceleration`` property; reads ``_acceleration``.'''
        value = self._acceleration
        return value

    def _get_deceleration(self):
        '''Getter for the ``deceleration`` property; reads ``_deceleration``.'''
        value = self._deceleration
        return value

    def upload_parameters_to_device(self):
        '''No-op in this class; returns None.'''
        return None

    @on_trait_change('_velocity, _acceleration, _deceleration')
    def update_machine_values(self, obj, name, old, new):
        '''Mirror a changed backing motion value into its ``machine*`` trait.

        The changed trait name already starts with an underscore, so e.g.
        ``'_velocity'`` maps to ``'machine_velocity'``.
        '''
        target = 'machine{}'.format(name)
        setattr(self, target, new)

    def _calibration_changed(self):
        '''Ask the parent to update its axes when ``calibration`` changes.'''
        owner = self.parent
        owner.update_axes()

    def simple_view(self):
        '''
        Build the traits UI View for this axis.

        Velocity, acceleration and deceleration are shown with three
        decimals and are only editable while ``calculate_parameters`` is off.
        '''
        manual_only = dict(format_str='%0.3f',
                           enabled_when='not calculate_parameters')
        items = [Item('calculate_parameters')]
        for trait_name in ('velocity', 'acceleration', 'deceleration'):
            items.append(Item(trait_name, **manual_only))
        items.append(Item('drive_ratio'))
        return View(*items)

    def full_view(self):
        '''Return the full view; currently identical to ``simple_view()``.'''
        view = self.simple_view()
        return view

    def dump(self):
        '''
        Currently a no-op; returns None.
        '''
        return None
#        self.loaded = False
#
#        p = os.path.join(self.pdir, '.%s' % self.name)
#        with open(p, 'w') as f:
#            pickle.dump(self, f)
#    def load_parameters_from_config(self, path):
#        self.config_path = path
#        self._load_parameters_from_config(path)
#
#    def load_parameters(self, pdir):
#        '''
#        '''
# #        self.pdir = pdir
# #        p = os.path.join(pdir, '.%s' % self.name)
# #
# #        if os.path.isfile(p):
# #            return p
# #        else:
#        self.load(pdir)

    def save(self):
        '''Currently a no-op; returns None.'''
        return None

    def ask(self, cmd):
        '''Forward ``cmd`` to the parent's ``ask`` and return its reply.'''
        owner = self.parent
        return owner.ask(cmd)

    def _get_parameters(self, path):
        '''
        Return every ``(option, value)`` pair from this axis' config file.

        ``path`` may name the config file itself; if it is not an existing
        file it is treated as a directory and ``<name>axis.cfg`` inside it is
        used. The file is loaded with ``self.get_configuration``
        (presumably a ConfigParser-like object -- TODO confirm); a falsy
        result yields an empty list. Otherwise the items of all sections are
        concatenated in the order ``sections()`` returns them.
        '''
        if not os.path.isfile(path):
            path = os.path.join(path, '{}axis.cfg'.format(self.name))

        config = self.get_configuration(path)
        if not config:
            return []
        return [item
                for section in config.sections()
                for item in config.items(section)]

    def _validate_float(self, v):
        try:
            v = float(v)
            return v
        except ValueError:
            pass