Пример #1
0
class TitleTemplater(BaseTemplater):
    """ Templater that builds titles from sample attributes.

    The label template itself is handled by ``BaseTemplater`` and used here
    through ``self.formatter``.  A title may additionally carry a leading and
    a trailing text; when such a text matches (case-insensitively) a key of
    ``example_context``, the preview strings show that key's example value
    instead of the literal text.
    """

    # Attribute tokens that may be used to compose a label.
    attributes = List([
        'Project', 'Sample', 'Identifier', 'Aliquot', 'Material',
        'AlphaCounter', 'NumericCounter', '<SPACE>'
    ])

    # Format spec per attribute, keyed by lower-cased attribute name.
    # NOTE(review): presumably consumed by BaseTemplater when it builds
    # ``formatter``; '02n' would render the aliquot zero-padded to 2 digits.
    attribute_formats = {
        'sample': '',
        'identifier': '',
        'project': '',
        'aliquot': '02n',
        'material': '',
        'numericcounter': '',
        'alphacounter': ''
    }

    # Sample values used to render the ``example`` previews.
    example_context = {
        'sample': 'NM-001',
        'identifier': '20001',
        'project': 'J-Curve',
        'aliquot': 1,
        'material': 'GMC',
        'numericcounter': 1,
        'alphacounter': 'A'
    }

    # Built-in label templates (tokens separated by spaces).
    base_predefined_labels = List([
        'Sample ( Identifier )', 'Sample ( Identifier - Aliquot )',
        'Sample ( Identifier - Aliquot , Material )',
        'AlphaCounter . <SPACE> Sample ( Identifier - Aliquot , Material )',
        'Sample', 'Project <SPACE> Sample ( Identifier )'
    ])

    # String inserted between groups in the multi-group preview.
    delimiter = Str
    # Available delimiters mapped to their display names.
    delimiters = Dict({
        ',': 'Comma',
        '\t': 'Tab',
        ' ': 'Space',
        ':': 'Colon',
        ';': 'Semicolon'
    })

    # Preview of a title containing a single group.
    example = Property(
        depends_on='label, delimiter, leading_text, trailing_text')
    # Preview of a title containing two groups joined by ``delimiter``.
    multi_group_example = Property(
        depends_on='label, delimiter, leading_text, trailing_text')
    # Text placed (separated by one space) before / after the title.
    leading_text = Str
    trailing_text = Str
    # NOTE(review): presumably the choices offered for leading/trailing text.
    leading_texts = List(['Project'])
    trailing_texts = List(['Project'])
    # Name used for persisting this templater (assumed to be read by
    # BaseTemplater -- confirm).
    persistence_name = 'title_maker'

    def _get_example(self):
        """ Return the single-group preview title. """
        return self._assemble_example(1)

    def _get_multi_group_example(self):
        """ Return the two-group preview title. """
        return self._assemble_example(2)

    def _resolve_example_text(self, text):
        """ Return the example value for *text* if it names an attribute.

        The lookup is case-insensitive against ``example_context``; text that
        does not name an attribute is returned unchanged.
        """
        return self.example_context.get(text.lower(), text)

    def _assemble_example(self, n):
        """ Render a preview title made of *n* groups.

        Each group is the label formatted with ``example_context``; groups
        are joined with ``delimiter`` and the result is wrapped in the
        (resolved) leading and trailing texts, each separated by one space.

        :param n: number of groups to render.
        :returns: the assembled preview string.
        """
        f = self.formatter
        groups = [f.format(**self.example_context) for _ in range(n)]
        t = self.delimiter.join(groups)

        # The emptiness checks apply to the raw user text, before any
        # substitution with an example value.
        lt = self.leading_text
        if lt:
            t = '{} {}'.format(self._resolve_example_text(lt), t)

        tt = self.trailing_text
        if tt:
            t = '{} {}'.format(t, self._resolve_example_text(tt))
        return t
Пример #2
0
class DemoFile(DemoFileBase):
    """ A demo backed by a Python source file.

    ``init`` parses the file into an HTML description and source code, then
    executes the source; the demo object it defines (if any) is stored in
    ``demo`` and the execution namespace in ``locals``.
    """

    #: Source code for the demo:
    source = Code()

    #: Demo object whose traits UI is to be displayed:
    demo = Instance(HasTraits)

    #: Local namespace for executed code:
    locals = Dict(Str, Any)

    def init(self):
        """ Parse the demo file, render its description and run its code. """
        super(DemoFile, self).init()
        description, source = parse_source(self.path)
        self.description = publish_html_str(description, self.css_filename)
        self.source = source
        self.run_code()

    def run_code(self):
        """ Runs the code associated with this demo file.

        The source is executed in the namespace returned by the parent's
        ``init_dic``.  The first of ``modal_popup`` / ``popup`` / ``demo``
        (or their capitalized forms) that yields a HasTraits object becomes
        ``self.demo``; popups are wrapped in ``ModalDemoButton`` /
        ``DemoButton`` respectively.  Errors are printed, not raised.

        Note: this executes arbitrary code from the demo file.
        """
        # Bound before the ``try`` so that a failure while obtaining the
        # namespace cannot leave the name unassigned below.
        namespace = None
        try:
            # Get the execution context dictionary:
            namespace = self.parent.init_dic
            # Deliberately not "__main__" -- presumably so a demo's
            # ``if __name__ == "__main__":`` block does not run; confirm.
            namespace["__name__"] = "___main___"
            namespace["__file__"] = self.path
            sys.modules["__main__"].__file__ = self.path

            exec(self.source, namespace, namespace)

            demo = self._get_object("modal_popup", namespace)
            if demo is not None:
                demo = ModalDemoButton(demo=demo)
            else:
                demo = self._get_object("popup", namespace)
                if demo is not None:
                    demo = DemoButton(demo=demo)
                else:
                    demo = self._get_object("demo", namespace)
        except Exception:
            traceback.print_exc()
        else:
            self.demo = demo

        # Keep the (possibly partially populated) namespace even when the
        # demo code failed, as long as one was obtained.
        if namespace is not None:
            self.locals = namespace

    # -------------------------------------------------------------------------
    #  Get a specified object from the execution dictionary:
    # -------------------------------------------------------------------------

    def _get_object(self, name, dic):
        """ Return the HasTraits object bound to *name* in *dic*, or None.

        *name* is looked up first, then ``name.capitalize()``.  A value that
        is callable but not already a HasTraits instance (e.g. a HasTraits
        subclass or a factory function) is called with no arguments and its
        result is used instead.
        """
        obj = dic.get(name)
        if obj is None:
            obj = dic.get(name.capitalize())
        if obj is None:
            return None

        # Only instantiate classes/factories; an existing HasTraits instance
        # must not be called even if it happens to define __call__.
        if callable(obj) and not isinstance(obj, HasTraits):
            try:
                obj = obj()
            except Exception:
                # Report why the demo could not be created instead of
                # silently dropping it.
                traceback.print_exc()
                return None

        if isinstance(obj, HasTraits):
            return obj

        return None
Пример #3
0
class Context(HasTraits):
    """ The base class for all naming contexts.

    Names are '/'-separated paths (see '_parse_name').  An operation on a
    single-component name takes place in this context.  For a name with more
    components, the first component must be bound to a context (or to an
    object the environment's type manager can adapt to one) and the rest of
    the name is resolved there.

    The methods in the protected interface ('_bind', '_lookup', '_unbind',
    ...) keep the bindings in the '_bindings' dictionary.

    """

    # Keys for environment properties.
    INITIAL_CONTEXT_FACTORY = INITIAL_CONTEXT_FACTORY
    OBJECT_FACTORIES = OBJECT_FACTORIES
    STATE_FACTORIES = STATE_FACTORIES

    # Non-JNDI.
    TYPE_MANAGER = TYPE_MANAGER

    #### 'Context' interface ##################################################

    # The naming environment in effect for this context.
    environment = Dict(ENVIRONMENT)

    # The name of the context within its own namespace.
    namespace_name = Property(Str)

    # The type manager in the context's environment (used to create context
    # adapters etc.).
    #
    # fixme: This is an experimental 'convenience' trait, since it is common
    # to get hold of the context's type manager to see if some object has a
    # context adapter.
    type_manager = Property(Instance(TypeManager))

    #### Events ####

    # Fired when an object has been added to the context (either via 'bind' or
    # 'create_subcontext').
    object_added = Event(NamingEvent)

    # Fired when an object has been changed (via 'rebind').
    object_changed = Event(NamingEvent)

    # Fired when an object has been removed from the context (either via
    # 'unbind' or 'destroy_subcontext').
    object_removed = Event(NamingEvent)

    # Fired when an object in the context has been renamed (via 'rename').
    object_renamed = Event(NamingEvent)

    # Fired when the contents of the context have changed dramatically.
    context_changed = Event(NamingEvent)

    #### Protected 'Context' interface #######################################

    # The bindings in the context (name -> state stored by '_bind').
    _bindings = Dict(Str, Any)

    ###########################################################################
    # 'Context' interface.
    ###########################################################################

    #### Properties ###########################################################

    def _get_namespace_name(self):
        """
        Return the name of the context within its own namespace.

        That is the full-path, through the namespace this context participates
        in, to get to this context.  For example, if the root context of the
        namespace was called 'Foo', and there was a subcontext of that called
        'Bar', and we were within that and called 'Baz', then this should
        return 'Foo/Bar/Baz'.

        This base implementation returns an empty string.

        """

        # FIXME: We'd like to raise an exception (OperationNotSupportedError)
        # and force implementors to decide what to do.  However, it appears
        # to be pretty common that most Context implementations do not
        # override this method -- possibly because the comments aren't clear
        # on what this is supposed to be?
        #
        # Anyway, if we raise an exception then it is impossible to use any
        # evaluations when building a Traits UI for a Context.  That is, the
        # Traits UI can't include items that have a 'visible_when' or
        # 'enabled_when' evaluation.  This is because the Traits evaluation
        # code calls the 'get()' method on the Context which attempts to
        # retrieve the current namespace_name value.
        return ''

    def _get_type_manager(self):
        """ Returns the type manager in the context's environment.

        This will return None if no type manager was used to create the initial
        context.

        """

        return self.environment.get(self.TYPE_MANAGER)

    #### Methods ##############################################################

    def bind(self, name, obj, make_contexts=False):
        """ Binds a name to an object.

        If 'make_contexts' is True then any missing intermediate contexts are
        created automatically.

        Raises 'InvalidNameError' if the name is empty,
        'NameAlreadyBoundError' if the final component is already bound,
        'NameNotFoundError' if an intermediate context is missing (and
        'make_contexts' is False) and 'NotContextError' if an intermediate
        name is not bound to a context.

        """

        if len(name) == 0:
            raise InvalidNameError('empty name')

        # Parse the name.
        components = self._parse_name(name)

        # If there is exactly one component in the name then the operation
        # takes place in this context.
        if len(components) == 1:
            atom = components[0]

            # Is the name already bound?
            if self._is_bound(atom):
                raise NameAlreadyBoundError(name)

            # Do the actual bind.
            self._bind(atom, obj)

            # Trait event notification.
            self.object_added = NamingEvent(
                new_binding=Binding(name=name, obj=obj, context=self))

        # Otherwise, attempt to continue resolution into the next context.
        else:
            if not self._is_bound(components[0]):
                if make_contexts:
                    self._create_subcontext(components[0])

                else:
                    raise NameNotFoundError(components[0])

            next_context = self._get_next_context(components[0])
            next_context.bind('/'.join(components[1:]), obj, make_contexts)

        return

    def rebind(self, name, obj, make_contexts=False):
        """ Binds an object to a name that may already be bound.

        If 'make_contexts' is True then any missing intermediate contexts are
        created automatically.

        The object may be a different object but may also be the same object
        that is already bound to the specified name. The name may or may not be
        already used. Think of this as a safer version of 'bind' since this
        one will never raise an exception regarding a name being used.

        Raises 'InvalidNameError' if the name is empty and
        'NameNotFoundError' if an intermediate context is missing (and
        'make_contexts' is False).

        """

        if len(name) == 0:
            raise InvalidNameError('empty name')

        # Parse the name.
        components = self._parse_name(name)

        # If there is exactly one component in the name then the operation
        # takes place in this context.
        if len(components) == 1:
            # Do the actual rebind.
            self._rebind(components[0], obj)

            # Trait event notification.
            self.object_changed = NamingEvent(
                new_binding=Binding(name=name, obj=obj, context=self))

        # Otherwise, attempt to continue resolution into the next context.
        else:
            if not self._is_bound(components[0]):
                if make_contexts:
                    self._create_subcontext(components[0])

                else:
                    raise NameNotFoundError(components[0])

            next_context = self._get_next_context(components[0])
            next_context.rebind('/'.join(components[1:]), obj, make_contexts)

        return

    def unbind(self, name):
        """ Unbinds a name.

        Raises 'InvalidNameError' if the name is empty and
        'NameNotFoundError' if it (or an intermediate context) is not bound.

        """

        if len(name) == 0:
            raise InvalidNameError('empty name')

        # Parse the name.
        components = self._parse_name(name)

        # If there is exactly one component in the name then the operation
        # takes place in this context.
        if len(components) == 1:
            atom = components[0]

            if not self._is_bound(atom):
                raise NameNotFoundError(name)

            # Lookup the object that we are unbinding to use in the event
            # notification.
            obj = self._lookup(atom)

            # Do the actual unbind.
            self._unbind(atom)

            # Trait event notification.
            self.object_removed = NamingEvent(
                old_binding=Binding(name=name, obj=obj, context=self))

        # Otherwise, attempt to continue resolution into the next context.
        else:
            if not self._is_bound(components[0]):
                raise NameNotFoundError(components[0])

            next_context = self._get_next_context(components[0])
            next_context.unbind('/'.join(components[1:]))

        return

    def rename(self, old_name, new_name):
        """ Binds a new name to an object.

        Raises 'InvalidNameError' if either name is empty,
        'NameNotFoundError' if the old name is not bound and
        'NameAlreadyBoundError' if the new name is already bound.

        """

        if len(old_name) == 0 or len(new_name) == 0:
            raise InvalidNameError('empty name')

        # Parse the names.
        old_components = self._parse_name(old_name)
        new_components = self._parse_name(new_name)

        # If there is exactly one component in BOTH names then the operation
        # takes place ENTIRELY in this context.
        if len(old_components) == 1 and len(new_components) == 1:
            # Is the old name actually bound?
            if not self._is_bound(old_name):
                raise NameNotFoundError(old_name)

            # Is the new name already bound?
            if self._is_bound(new_name):
                raise NameAlreadyBoundError(new_name)

            # Do the actual rename.
            self._rename(old_name, new_name)

            # Lookup the object that we are renaming to use in the event
            # notification.
            obj = self._lookup(new_name)

            # Trait event notification.
            self.object_renamed = NamingEvent(
                old_binding=Binding(name=old_name, obj=obj, context=self),
                new_binding=Binding(name=new_name, obj=obj, context=self))

        else:
            # fixme: This really needs to be transactional in case the bind
            # succeeds but the unbind fails.  To be safe should we just not
            # support cross-context renaming for now?!?!
            #
            # Lookup the object.
            obj = self.lookup(old_name)

            # Bind the new name.
            self.bind(new_name, obj)

            # Unbind the old one.
            self.unbind(old_name)

        return

    def lookup(self, name):
        """ Resolves a name relative to this context.

        An empty name returns this context itself.  Raises
        'NameNotFoundError' if the name (or an intermediate context) is not
        bound.

        """

        # If the name is empty we return the context itself.
        if len(name) == 0:
            # fixme: The JNDI spec. says that this should return a COPY of
            # the context.
            return self

        # Parse the name.
        components = self._parse_name(name)

        # If there is exactly one component in the name then the operation
        # takes place in this context.
        if len(components) == 1:
            atom = components[0]

            if not self._is_bound(atom):
                raise NameNotFoundError(name)

            # Do the actual lookup.
            obj = self._lookup(atom)

        # Otherwise, attempt to continue resolution into the next context.
        else:
            if not self._is_bound(components[0]):
                raise NameNotFoundError(components[0])

            next_context = self._get_next_context(components[0])
            obj = next_context.lookup('/'.join(components[1:]))

        return obj

    # fixme: Non-JNDI
    def lookup_binding(self, name):
        """ Looks up the binding for a name relative to this context.

        Raises 'InvalidNameError' if the name is empty and
        'NameNotFoundError' if it (or an intermediate context) is not bound.

        """

        if len(name) == 0:
            raise InvalidNameError('empty name')

        # Parse the name.
        components = self._parse_name(name)

        # If there is exactly one component in the name then the operation
        # takes place in this context.
        if len(components) == 1:
            atom = components[0]

            if not self._is_bound(atom):
                raise NameNotFoundError(name)

            # Do the actual lookup.
            binding = self._lookup_binding(atom)

        # Otherwise, attempt to continue resolution into the next context.
        else:
            if not self._is_bound(components[0]):
                raise NameNotFoundError(components[0])

            next_context = self._get_next_context(components[0])
            binding = next_context.lookup_binding('/'.join(components[1:]))

        return binding

    # fixme: Non-JNDI
    def lookup_context(self, name):
        """ Resolves a name relative to this context.

        The name MUST resolve to a context. This method is useful to return
        context adapters.

        An empty name returns this context itself.  Raises
        'NameNotFoundError' if the name is not bound and 'NotContextError'
        if it is bound to something that is not (and cannot be adapted to) a
        context.

        """

        # If the name is empty we return the context itself.
        if len(name) == 0:
            # fixme: The JNDI spec. says that this should return a COPY of
            # the context.
            return self

        # Parse the name.
        components = self._parse_name(name)

        # If there is exactly one component in the name then the operation
        # takes place in this context.
        if len(components) == 1:
            atom = components[0]

            if not self._is_bound(atom):
                raise NameNotFoundError(name)

            # Do the actual lookup (adapting the object if necessary).
            obj = self._get_next_context(atom)

        # Otherwise, attempt to continue resolution into the next context.
        else:
            if not self._is_bound(components[0]):
                raise NameNotFoundError(components[0])

            next_context = self._get_next_context(components[0])
            # Recurse with 'lookup_context' (not 'lookup') so that the final
            # component is also resolved/adapted to a context.
            obj = next_context.lookup_context('/'.join(components[1:]))

        return obj

    def create_subcontext(self, name):
        """ Creates a sub-context.

        Returns the new sub-context.  Raises 'InvalidNameError' if the name
        is empty, 'NameAlreadyBoundError' if it is already bound and
        'NameNotFoundError' if an intermediate context is not bound.

        """

        if len(name) == 0:
            raise InvalidNameError('empty name')

        # Parse the name.
        components = self._parse_name(name)

        # If there is exactly one component in the name then the operation
        # takes place in this context.
        if len(components) == 1:
            atom = components[0]

            # Is the name already bound?
            if self._is_bound(atom):
                raise NameAlreadyBoundError(name)

            # Do the actual creation of the sub-context.
            sub = self._create_subcontext(atom)

            # Trait event notification.
            self.object_added = NamingEvent(
                new_binding=Binding(name=name, obj=sub, context=self))

        # Otherwise, attempt to continue resolution into the next context.
        else:
            if not self._is_bound(components[0]):
                raise NameNotFoundError(components[0])

            next_context = self._get_next_context(components[0])
            sub = next_context.create_subcontext('/'.join(components[1:]))

        return sub

    def destroy_subcontext(self, name):
        """ Destroys a sub-context.

        Raises 'InvalidNameError' if the name is empty, 'NameNotFoundError'
        if it is not bound and 'NotContextError' if it is not bound to a
        context.

        """

        if len(name) == 0:
            raise InvalidNameError('empty name')

        # Parse the name.
        components = self._parse_name(name)

        # If there is exactly one component in the name then the operation
        # takes place in this context.
        if len(components) == 1:
            atom = components[0]

            if not self._is_bound(atom):
                raise NameNotFoundError(name)

            obj = self._lookup(atom)
            if not self._is_context(atom):
                raise NotContextError(name)

            # Do the actual destruction of the sub-context.
            self._destroy_subcontext(atom)

            # Trait event notification.
            self.object_removed = NamingEvent(
                old_binding=Binding(name=name, obj=obj, context=self))

        # Otherwise, attempt to continue resolution into the next context.
        else:
            if not self._is_bound(components[0]):
                raise NameNotFoundError(components[0])

            next_context = self._get_next_context(components[0])
            next_context.destroy_subcontext('/'.join(components[1:]))

        return

    # fixme: Non-JNDI
    def get_unique_name(self, prefix):
        """ Returns a name that is unique within the context.

        The name returned will start with the specified prefix; uniqueness is
        checked against the names bound directly in this context, using the
        '%s (%d)' pattern for disambiguated names.

        """

        return make_unique_name(prefix,
                                existing=self.list_names(''),
                                format='%s (%d)')

    def list_names(self, name=''):
        """ Lists the names bound in a context.

        An empty name lists this context; otherwise the named sub-context is
        listed.  Raises 'NameNotFoundError' if the name is not bound.

        """

        # If the name is empty then the operation takes place in this context.
        if len(name) == 0:
            names = self._list_names()

        # Otherwise, attempt to continue resolution into the next context.
        else:
            # Parse the name.
            components = self._parse_name(name)

            if not self._is_bound(components[0]):
                raise NameNotFoundError(components[0])

            next_context = self._get_next_context(components[0])
            names = next_context.list_names('/'.join(components[1:]))

        return names

    def list_bindings(self, name=''):
        """ Lists the bindings in a context.

        An empty name lists this context; otherwise the named sub-context is
        listed.  Raises 'NameNotFoundError' if the name is not bound.

        """

        # If the name is empty then the operation takes place in this context.
        if len(name) == 0:
            bindings = self._list_bindings()

        # Otherwise, attempt to continue resolution into the next context.
        else:
            # Parse the name.
            components = self._parse_name(name)

            if not self._is_bound(components[0]):
                raise NameNotFoundError(components[0])

            next_context = self._get_next_context(components[0])
            bindings = next_context.list_bindings('/'.join(components[1:]))

        return bindings

    # fixme: Non-JNDI
    def is_context(self, name):
        """ Returns True if the name is bound to a context.

        An empty name refers to this context (True).  Returns False if the
        name is bound to something that is not (and cannot be adapted to) a
        context.  Raises 'NameNotFoundError' if the name is not bound.

        """

        # If the name is empty then it refers to this context.
        if len(name) == 0:
            is_context = True

        else:
            # Parse the name.
            components = self._parse_name(name)

            # If there is exactly one component in the name then the operation
            # takes place in this context.
            if len(components) == 1:
                atom = components[0]

                if not self._is_bound(atom):
                    raise NameNotFoundError(name)

                # Do the actual check.
                is_context = self._is_context(atom)

            # Otherwise, attempt to continue resolution into the next context.
            else:
                if not self._is_bound(components[0]):
                    raise NameNotFoundError(components[0])

                next_context = self._get_next_context(components[0])
                is_context = next_context.is_context('/'.join(components[1:]))

        return is_context

    # fixme: Non-JNDI
    def search(self, obj):
        """ Returns a list of namespace names that are bound to obj.

        Names are '/'-joined paths relative to this context; identity ('is')
        is used to match.  Searching for None returns an empty list.

        """

        # don't look for None
        if obj is None:
            return []

        # Obj is bound to these names relative to this context
        names = []

        # path contain the name components down to the current context
        path = []

        self._search(obj, names, path, {})

        return names

    ###########################################################################
    # Protected 'Context' interface.
    ###########################################################################

    def _parse_name(self, name):
        """ Parse a name into a list of components.

        e.g. 'foo/bar/baz' -> ['foo', 'bar', 'baz']

        """

        return name.split('/')

    def _is_bound(self, name):
        """ Is a name bound in this context? """

        return name in self._bindings

    def _lookup(self, name):
        """ Looks up a name in this context.

        The stored state is turned back into an object by the naming
        manager.

        """

        obj = self._bindings[name]

        return naming_manager.get_object_instance(obj, name, self)

    def _lookup_binding(self, name):
        """ Looks up the binding for a name in this context. """

        return Binding(name=name, obj=self._lookup(name), context=self)

    def _bind(self, name, obj):
        """ Binds a name to an object in this context.

        What is stored is the state returned by the naming manager for obj.

        """

        state = naming_manager.get_state_to_bind(obj, name, self)
        self._bindings[name] = state

        return

    def _rebind(self, name, obj):
        """ Rebinds a name to an object in this context. """

        self._bind(name, obj)

        return

    def _unbind(self, name):
        """ Unbinds a name from this context. """

        del self._bindings[name]

        return

    def _rename(self, old_name, new_name):
        """ Renames an object in this context. """

        # Bind the new name.
        self._bindings[new_name] = self._bindings[old_name]

        # Unbind the old one.
        del self._bindings[old_name]

        return

    def _create_subcontext(self, name):
        """ Creates a sub-context of this context.

        The sub-context is an instance of this context's class sharing the
        same environment.

        """

        sub = self.__class__(environment=self.environment)
        self._bindings[name] = sub

        return sub

    def _destroy_subcontext(self, name):
        """ Destroys a sub-context of this context. """

        del self._bindings[name]

        return

    def _list_bindings(self):
        """ Lists the bindings in this context. """

        bindings = []
        for name in self._list_names():
            bindings.append(
                Binding(name=name, obj=self._lookup(name), context=self))

        return bindings

    def _list_names(self):
        """ Lists the names bound in this context. """

        return list(self._bindings.keys())

    def _is_context(self, name):
        """ Returns True if a name is bound to a context.

        An object that is not a 'Context' but can be adapted to one by the
        type manager also counts as a context.

        """

        # '_get_next_context' signals "not a context" by raising, so that has
        # to be translated into False here.  A None result (possible in an
        # overriding implementation) is treated the same way.
        try:
            next_context = self._get_next_context(name)

        except NotContextError:
            return False

        return next_context is not None

    def _get_next_context(self, name):
        """ Returns the next context.

        Raises 'NotContextError' if the object bound to the name is neither
        a context nor adaptable to one.

        """

        obj = self._lookup(name)

        # If the object is a context then everything is just dandy.
        if isinstance(obj, Context):
            next_context = obj

        # Otherwise, instead of just giving up, see if the context has a type
        # manager that knows how to adapt the object to make it quack like a
        # context.
        else:
            next_context = self._get_context_adapter(obj)

            # If no adapter was found then we cannot continue name resolution.
            if next_context is None:
                raise NotContextError(name)

        return next_context

    def _search(self, obj, names, path, searched):
        """ Append to names any name bound to obj.
            Join path and name with '/' to form a complete name from the
            top context.

            'searched' records the sub-contexts already visited so that each
            one is only searched once.
        """

        # Check the bindings recursively.
        for binding in self.list_bindings():
            if binding.obj is obj:
                path.append(binding.name)
                names.append('/'.join(path))
                path.pop()

            if isinstance(binding.obj, Context) \
                    and binding.obj not in searched:
                path.append(binding.name)
                searched[binding.obj] = True
                binding.obj._search(obj, names, path, searched)
                path.pop()

        return

    ###########################################################################
    # Private interface.
    ###########################################################################

    def _get_context_adapter(self, obj):
        """ Returns a context adapter for an object.

        Returns None if no such adapter is available.

        """

        if self.type_manager is not None:
            adapter = self.type_manager.object_as(obj,
                                                  Context,
                                                  environment=self.environment,
                                                  context=self)

        else:
            adapter = None

        return adapter
Пример #4
0
class UI(HasPrivateTraits):
    """ Information about the user interface for a View.
    """

    # -------------------------------------------------------------------------
    #  Trait definitions:
    # -------------------------------------------------------------------------

    #: The ViewElements object from which this UI resolves Include items
    view_elements = Instance(ViewElements)

    #: Context objects that the UI is editing
    context = Dict(Str, Any)

    #: Handler object used for event handling
    handler = Instance(Handler)

    #: View template used to construct the user interface
    view = Instance("traitsui.view.View")

    #: Panel or dialog associated with the user interface
    control = Any()

    #: The parent UI (if any) of this UI
    parent = Instance("UI")

    #: Toolkit-specific object that "owns" **control**
    owner = Any()

    #: UIInfo object containing context or editor objects
    info = Instance(UIInfo)

    #: Result from a modal or wizard dialog:
    result = Bool(False)

    #: Undo and Redo history
    history = Any()

    #: The KeyBindings object (if any) for this UI:
    key_bindings = Property(depends_on=["view._key_bindings", "context"])

    #: The unique ID for this UI for persistence
    id = Str()

    #: Have any modifications been made to UI contents?
    modified = Bool(False)

    #: Event when the user interface has changed
    updated = Event(Bool)

    #: Title of the dialog, if any
    title = Str()

    #: The ImageResource of the icon, if any
    icon = Image

    #: Should the created UI have scroll bars?
    scrollable = Bool(False)

    #: The number of currently pending editor error conditions
    errors = Int()

    #: The code used to rebuild an updated user interface
    rebuild = Callable()

    #: Set to True when the UI has finished being destroyed.
    destroyed = Bool(False)

    # -- Private Traits -------------------------------------------------------

    #: Original context when used with a modal dialog
    _context = Dict(Str, Any)

    #: Copy of original context used for reverting changes
    _revert = Dict(Str, Any)

    #: List of methods to call once the user interface is created
    _defined = List()

    #: List of (visible_when,Editor) pairs
    _visible = List()

    #: List of (enabled_when,Editor) pairs
    _enabled = List()

    #: List of (checked_when,Editor) pairs
    _checked = List()

    #: Search stack used while building a user interface
    _search = List()

    #: List of dispatchable Handler methods
    _dispatchers = List()

    #: List of editors used to build the user interface
    _editors = List()

    #: List of names bound to the **info** object
    _names = List()

    #: Index of currently the active group in the user interface
    _active_group = Int()

    #: List of top-level groups used to build the user interface
    _groups = Property()
    _groups_cache = Any()

    #: Count of levels of nesting for undoable actions
    _undoable = Int(-1)

    #: Code used to rebuild an updated user interface
    _rebuild = Callable()

    #: The statusbar listeners that have been set up:
    _statusbar = List()

    #: Control which gets focus after UI is created
    #: Note: this does not track focus after UI creation
    #: only used by Qt backend.
    _focus_control = Any()

    #: Does the UI contain any scrollable widgets?
    #:
    #: The _scrollable trait is set correctly, but not used currently because
    #: its value is arrived at too late to be of use in building the UI.
    _scrollable = Bool(False)

    #: Cache for key bindings.
    _key_bindings = Instance("traitsui.key_bindings.KeyBindings")

    #: List of traits that are reset when a user interface is recycled
    #: (i.e. rebuilt).
    recyclable_traits = [
        "_context",
        "_revert",
        "_defined",
        "_visible",
        "_enabled",
        "_checked",
        "_search",
        "_dispatchers",
        "_editors",
        "_names",
        "_active_group",
        "_undoable",
        "_rebuild",
        "_groups_cache",
        "_key_bindings",
        "_focus_control",
    ]

    #: List of additional traits that are discarded when a user interface is
    #: disposed.
    disposable_traits = [
        "view_elements",
        "info",
        "handler",
        "context",
        "view",
        "history",
        "key_bindings",
        "icon",
        "rebuild",
    ]

    def traits_init(self):
        """ Create this UI's UIInfo object and hand it to the handler for
        initialization.
        """
        ui_info = UIInfo(ui=self)
        self.info = ui_info
        handler = self.handler
        handler.init_info(self.info)

    def ui(self, parent, kind):
        """ Build the user interface from the associated View template.

        ``kind`` selects the toolkit builder ``toolkit().ui_<kind>``; kinds
        listed in ``kind_must_have_parent`` fall back to "live" when no
        parent is given.  The chosen builder is stored in ``rebuild`` and
        invoked immediately with ``(self, parent)``.
        """
        if parent is None and kind in kind_must_have_parent:
            kind = "live"

        # Route the View's 'updated' event to this UI's handler.
        view = self.view
        view.on_trait_change(
            self._updated_changed, "updated", dispatch="ui"
        )

        builder_name = "ui_" + kind
        self.rebuild = getattr(toolkit(), builder_name)
        self.rebuild(self, parent)

    def dispose(self, result=None, abort=False):
        """ Dispose of the contents of the user interface.

        Parameters
        ----------
        result : object
            If not None, stored as the UI's ``result`` before tearing down.
        abort : bool
            If True, the user preferences are not saved.
        """
        parent = self.parent
        if parent is not None:
            parent.errors -= self.errors

        if result is not None:
            self.result = result

        # A UI without a control has already been disposed of.
        if self.control is None:
            return

        if not abort:
            self.save_prefs()
        self.finish()

    def recycle(self):
        """ Prepare the user interface for being rebuilt in place.
        """
        # Dispose of the editors but keep the toolkit controls alive.
        self.reset(destroy=False)

        # Forget the context object attached to the view control.
        view_control = self.control
        view_control._object = None

        self.reset_traits(self.recyclable_traits)

    def finish(self):
        """ Finishes disposing of a user interface.

        In order: destroys the toolkit control, resets the editors, detaches
        the ``*_when`` listener from every context object, notifies the
        handler via ``closed(info, result)``, drops references to the control,
        disposes of the key bindings, clears the context, resets the
        recyclable and disposable traits and finally sets ``destroyed``.
        """
        # Destroy the control early to silence cascade events when the UI
        # enters an inconsistent state.
        toolkit().destroy_control(self.control)

        # Reset the contents of the user interface (destroy=False: the control
        # itself was already destroyed above):
        self.reset(destroy=False)

        # Make sure that 'visible', 'enabled', and 'checked' handlers are not
        # called after the editor has been disposed:
        for object in self.context.values():
            object.on_trait_change(self._evaluate_when, remove=True)

        # Notify the handler that the view has been closed:
        self.handler.closed(self.info, self.result)

        # Clear the back-link from the UIInfo object to us:
        self.info.ui = None

        # Drop our references to the (already destroyed) view control:
        self.control._object = None
        self.control = None

        # Dispose of any KeyBindings object we reference:
        if self._key_bindings is not None:
            self._key_bindings.dispose()

        # Break the linkage to any objects in the context dictionary:
        self.context.clear()

        # Remove specified symbols from our dictionary to aid in clean-up:
        self.reset_traits(self.recyclable_traits)
        self.reset_traits(self.disposable_traits)

        self.destroyed = True

    def reset(self, destroy=True):
        """ Resets the contents of a user interface.

        Disposes of every editor (propagating ``result`` to any nested UI and
        clearing each editor's control), removes the statusbar observers,
        optionally destroys the children of the UI control, and removes all
        handler dispatchers.

        Parameters
        ----------
        destroy : bool
            If True, also ask the toolkit to destroy the children of
            ``self.control``.
        """
        for editor in self._editors:
            if editor._ui is not None:
                # Propagate result to enclosed ui objects:
                editor._ui.result = self.result
            editor.dispose()

            # Zap the control. If there are pending events for the control in
            # the UI queue, the editor's '_update_editor' method will see that
            # the control is None and discard the update request:
            editor.control = None

        # Remove any statusbar listeners that have been set up:
        for object, handler, name in self._statusbar:
            object.observe(handler, name, remove=True, dispatch="ui")

        del self._statusbar[:]

        if destroy:
            toolkit().destroy_children(self.control)

        # NOTE(review): the dispatchers are removed but the _dispatchers list
        # is not cleared here; the callers in this file (recycle, finish)
        # clear it afterwards via recyclable_traits.
        for dispatcher in self._dispatchers:
            dispatcher.remove()

    def find(self, include):
        """ Locate the definition referenced by an Include object.

        Lookup order: the UI's ViewElements (or, if none, the context
        object's), then a callable attribute named ``include.id`` on the
        context "handler", then the same on the context object.

        Returns the first non-None definition found, or None.
        """
        context = self.context

        # The context object is the sole entry, or else the "object" entry.
        if len(context) == 1:
            obj = next(iter(context.values()))
        else:
            obj = context.get("object")

        elements = self.view_elements
        if elements is None and obj is not None:
            elements = obj.trait_view_elements()

        found = None
        if elements is not None:
            found = elements.find(include.id, self._search)
        if found is not None:
            return found

        # Fall back to factory methods on the handler, then on the object.
        for source in (context.get("handler"), obj):
            if source is None:
                continue
            factory = getattr(source, include.id, None)
            if callable(factory):
                found = factory()
                if found is not None:
                    return found
        return found

    def push_level(self):
        """ Return the current depth of the search stack, for ``pop_level``.
        """
        depth = len(self._search)
        return depth

    def pop_level(self, level):
        """ Restore the search stack to a depth obtained from ``push_level``.

        Entries beyond ``level`` are removed from the front of the list.
        """
        surplus = len(self._search) - level
        del self._search[:surplus]

    def prepare_ui(self):
        """ Performs all processing that occurs after the user interface is
            created.

        Calls the accumulated ``*_defined`` methods, synchronizes the sync
        metadata of the context objects, hooks toolkit events, runs the
        handler's ``init``, wires handler ``<object>_<trait>_changed`` methods
        to the context objects, sets up ``*_when`` evaluation and finally
        marks ``info.initialized``.

        Raises
        ------
        TraitError
            If ``handler.init(info)`` returns a value equal to False.
        """
        # Invoke all of the editor 'name_defined' methods we've accumulated
        # (trait_set returns the UIInfo object itself):
        info = self.info.trait_set(initialized=False)
        for method in self._defined:
            method(info)

        # Then reset the list, since we don't need it anymore:
        del self._defined[:]

        # Synchronize all context traits with associated editor traits:
        self.sync_view()

        # Hook all keyboard events:
        toolkit().hook_events(self, self.control, "keys", self.key_handler)

        # Hook all events if the handler is an extended 'ViewHandler':
        handler = self.handler
        if isinstance(handler, ViewHandler):
            toolkit().hook_events(self, self.control)

        # Invoke the handler's 'init' method, and abort if it indicates
        # failure (a None return does not abort; only a value == False does):
        if handler.init(info) == False:
            raise TraitError("User interface creation aborted")

        # For each Handler method whose name is of the form
        # 'object_name_changed', where 'object' is the name of an object in the
        # UI's 'context', create a trait notification handler that will call
        # the method whenever 'object's 'name' trait changes. Also invoke the
        # method immediately so initial user interface state can be correctly
        # set:
        context = self.context
        for name in self._each_trait_method(handler):
            if name[-8:] == "_changed":
                prefix = name[:-8]
                # Split at the first '_' after position 0, so a leading
                # underscore is not treated as the separator; the context
                # object name therefore cannot itself contain '_'.
                col = prefix.find("_", 1)
                if col >= 0:
                    object = context.get(prefix[:col])
                    if object is not None:
                        method = getattr(handler, name)
                        trait_name = prefix[col + 1:]
                        self._dispatchers.append(
                            Dispatcher(method, info, object, trait_name))
                        # Event traits have no state, so don't fire them now:
                        if object.base_trait(trait_name).type != "event":
                            method(info)

        # If there are any Editor objects whose 'visible', 'enabled' or
        # 'checked' state is controlled by a 'visible_when', 'enabled_when' or
        # 'checked_when' expression, set up an 'anytrait' changed notification
        # handler on each object in the 'context' that will cause the
        # 'visible', 'enabled' or 'checked' state of each affected Editor to be
        #  set. Also trigger the evaluation immediately, so the visible,
        # enabled or checked state of each Editor can be correctly initialized:
        if (len(self._visible) + len(self._enabled) + len(self._checked)) > 0:
            for object in context.values():
                object.on_trait_change(self._evaluate_when, dispatch="ui")
            self._do_evaluate_when(at_init=True)

        # Indicate that the user interface has been initialized:
        info.initialized = True

    def sync_view(self):
        """ Wire context-object traits to editor traits via sync metadata.

        For every context object, the ``sync_to_view``, ``sync_from_view``
        and ``sync_with_view`` metadata are processed (in that order), each
        paired with the direction passed on to ``Editor.sync_value``.
        """
        metadata_directions = (
            ("sync_to_view", "from"),
            ("sync_from_view", "to"),
            ("sync_with_view", "both"),
        )
        for name, obj in self.context.items():
            for metadata, direction in metadata_directions:
                self._sync_view(name, obj, metadata, direction)

    def _sync_view(self, name, object, metadata, direction):
        """ Set up editor synchronization for one context object and metadata.

        Every trait of ``object`` whose ``metadata`` value is a string is
        parsed as a comma-separated list of ``editor_id.editor_trait``
        references; each referenced editor is synced with
        ``<name>.<trait_name>`` using ``direction``.

        Parameters
        ----------
        name : str
            Name of the context object.
        object : HasTraits
            The context object.
        metadata : str
            One of the ``sync_*_view`` metadata names.
        direction : str
            Direction passed to ``Editor.sync_value``.

        Raises
        ------
        TraitError
            If a reference is malformed or names an unknown editor id.
        """
        info = self.info
        for trait_name, trait in object.traits(**{metadata: is_str}).items():
            for sync in getattr(trait, metadata).split(","):
                try:
                    editor_id, editor_name = [
                        item.strip() for item in sync.split(".")
                    ]
                except ValueError as exc:
                    # Raised when the reference does not contain exactly one
                    # '.'; only that case is a format error.
                    raise TraitError(
                        "The '%s' metadata for the '%s' trait in "
                        "the '%s' context object should be of the form: "
                        "'id1.trait1[,...,idn.traitn]'." %
                        (metadata, trait_name, name)) from exc

                editor = getattr(info, editor_id, None)
                if editor is None:
                    raise TraitError(
                        "No editor with id = '%s' was found for "
                        "the '%s' metadata for the '%s' trait in the '%s' "
                        "context object." %
                        (editor_id, metadata, trait_name, name))

                editor.sync_value("%s.%s" % (name, trait_name),
                                  editor_name, direction)

    def get_extended_value(self, name):
        """ Return the current value of an extended (dotted) trait name.

        For a dotted name the first component selects the context entry and
        the rest is an attribute path; an undotted name is read from the
        context's "object" entry.
        """
        head, *rest = name.split(".")
        if rest:
            value = self.context[head]
            attr_path = rest
        else:
            value = self.context["object"]
            attr_path = [head]

        for attr in attr_path:
            value = getattr(value, attr)
        return value

    def restore_prefs(self):
        """ Retrieves and restores any saved user preference information
        associated with the UI.

        Preferences are read from the Traits UI preference database under the
        UI's ``id`` and applied with ``set_prefs``. This is best-effort:
        nothing is restored when ``id`` is empty, the database cannot be
        opened, or reading/applying the preferences fails.

        Returns
        -------
        The UI-level preferences returned by ``set_prefs``, or None.
        """
        ui_id = self.id
        if ui_id == "":
            return None

        db = self.get_ui_db()
        if db is None:
            return None

        try:
            try:
                ui_prefs = db.get(ui_id)
            finally:
                # Release the database even if the lookup fails.
                db.close()
            return self.set_prefs(ui_prefs)
        except Exception:
            # Deliberately best-effort: failures while reading or applying
            # saved preferences are ignored.
            return None

    def set_prefs(self, prefs):
        """ Apply saved user preferences to this UI.

        ``prefs`` is expected to be a dict shaped like the output of
        ``get_prefs``: per-editor entries keyed by editor name, key bindings
        under "$" and the UI-level preferences under "".

        Returns the UI-level preferences (``prefs.get("")``), or None if
        ``prefs`` is not a dict.
        """
        if not isinstance(prefs, dict):
            return None

        info = self.info
        for editor_name in self._names:
            editor = getattr(info, editor_name, None)
            if not (isinstance(editor, Editor) and editor.ui is self):
                continue
            saved = prefs.get(editor_name)
            if saved is not None:
                editor.restore_prefs(saved)

        bindings = self.key_bindings
        if bindings is not None:
            saved_bindings = prefs.get("$")
            if saved_bindings is not None:
                bindings.merge(saved_bindings)

        return prefs.get("")

    def save_prefs(self, prefs=None):
        """ Saves any user preference information associated with the UI.

        With ``prefs`` None, saving is delegated to the toolkit's
        ``save_window``. Otherwise ``prefs`` is combined with the editors'
        preferences by ``get_prefs`` and stored in the Traits UI preference
        database under the UI's ``id``. Nothing is stored when ``id`` is
        empty or the database cannot be opened.
        """
        if prefs is None:
            toolkit().save_window(self)
            return

        ui_id = self.id
        if ui_id == "":
            return

        db = self.get_ui_db(mode="c")
        if db is None:
            return
        try:
            db[ui_id] = self.get_prefs(prefs)
        finally:
            # Always release the database, even if collecting or storing the
            # preferences raises.
            db.close()

    def get_prefs(self, prefs=None):
        """ Collect the preferences to persist for this user interface.

        Returns a dict holding ``prefs`` under "" (when given), the key
        bindings under "$" (when present) and, for each editor owned by this
        UI, its non-None saved preferences under the editor's name.
        """
        collected = {}
        if prefs is not None:
            collected[""] = prefs

        if self.key_bindings is not None:
            collected["$"] = self.key_bindings

        info = self.info
        for editor_name in self._names:
            editor = getattr(info, editor_name, None)
            if not isinstance(editor, Editor) or editor.ui is not self:
                continue
            editor_prefs = editor.save_prefs()
            if editor_prefs is not None:
                collected[editor_name] = editor_prefs

        return collected

    def get_ui_db(self, mode="r"):
        """ Returns a reference to the Traits UI preference database.

        Parameters
        ----------
        mode : str
            ``shelve.open`` flag, e.g. "r" (read-only) or "c" (create if
            missing).

        Returns
        -------
        The open shelf, or None if it could not be opened. Callers are
        responsible for closing it.
        """
        try:
            return shelve.open(
                os.path.join(traits_home(), "traits_ui"),
                flag=mode,
                protocol=-1,
            )
        except Exception:
            # The preference database is optional: failing to open it just
            # means no preferences are available. KeyboardInterrupt and
            # SystemExit are no longer swallowed.
            return None

    def get_editors(self, name):
        """ Return every editor of this UI that is bound to the trait ``name``.
        """
        matching = []
        for editor in self._editors:
            if editor.name == name:
                matching.append(editor)
        return matching

    def get_error_controls(self):
        """ Return the error controls of all editors, flattened into one list.

        An editor may report either a single control or a list of controls.
        """
        error_controls = []
        for editor in self._editors:
            reported = editor.get_error_control()
            if isinstance(reported, list):
                error_controls += reported
            else:
                error_controls.append(reported)
        return error_controls

    def add_defined(self, method):
        """ Queue a Handler method to be called with the UIInfo once the user
            interface has been constructed.
        """
        pending = self._defined
        pending.append(method)

    def add_visible(self, visible_when, editor):
        """ Adds a conditionally visible Editor object to the list of monitored
            'visible_when' objects.

        Parameters
        ----------
        visible_when : str
            Python expression, compiled here and evaluated later in the UI's
            context to set the editor's ``visible`` state.
        editor : Editor
            The editor whose visibility is controlled.

        An expression that fails to compile is logged and ignored.
        """
        try:
            code = compile(visible_when, "<string>", "eval")
        except Exception:
            import logging

            logging.getLogger(__name__).warning(
                "Ignoring invalid 'visible_when' expression %r",
                visible_when, exc_info=True,
            )
            return
        self._visible.append((code, editor))

    def add_enabled(self, enabled_when, editor):
        """ Adds a conditionally enabled Editor object to the list of monitored
            'enabled_when' objects.

        Parameters
        ----------
        enabled_when : str
            Python expression, compiled here and evaluated later in the UI's
            context to set the editor's ``enabled`` state.
        editor : Editor
            The editor whose enabled state is controlled.

        An expression that fails to compile is logged and ignored.
        """
        try:
            code = compile(enabled_when, "<string>", "eval")
        except Exception:
            import logging

            logging.getLogger(__name__).warning(
                "Ignoring invalid 'enabled_when' expression %r",
                enabled_when, exc_info=True,
            )
            return
        self._enabled.append((code, editor))

    def add_checked(self, checked_when, editor):
        """ Adds a conditionally checked (menu) Editor object to the list of
            monitored 'checked_when' objects.

        Parameters
        ----------
        checked_when : str
            Python expression, compiled here and evaluated later in the UI's
            context to set the editor's ``checked`` state.
        editor : Editor
            The (menu) editor whose checked state is controlled.

        An expression that fails to compile is logged and ignored.
        """
        try:
            code = compile(checked_when, "<string>", "eval")
        except Exception:
            import logging

            logging.getLogger(__name__).warning(
                "Ignoring invalid 'checked_when' expression %r",
                checked_when, exc_info=True,
            )
            return
        self._checked.append((code, editor))

    def do_undoable(self, action, *args, **kw):
        """ Run ``action(*args, **kw)`` as an undoable operation.

        The outermost call records ``history.now`` in ``_undoable`` (when a
        history exists) and restores ``_undoable`` to -1 afterwards, even if
        the action raises. Nested calls leave ``_undoable`` untouched.
        """
        outermost = self._undoable == -1
        try:
            if outermost and self.history is not None:
                self._undoable = self.history.now
            action(*args, **kw)
        finally:
            if outermost:
                self._undoable = -1

    def route_event(self, event):
        """ Forward a "hooked" toolkit event to the toolkit's event router.
        """
        active_toolkit = toolkit()
        active_toolkit.route_event(self, event)

    def key_handler(self, event, skip=True):
        """ Handle a key event with this UI's key bindings, then its parent's.

        Parameters
        ----------
        event : toolkit event
            The key event.
        skip : bool
            If True and the event was not handled, ``toolkit().skip_event``
            is called for it.

        Returns the (truthy or falsy) handled result.
        """
        bindings = self.key_bindings
        if bindings is None:
            handled = False
        else:
            handled = bindings.do(
                event, [], self.info, recursive=(self.parent is None)
            )

        if not handled and self.parent is not None:
            handled = self.parent.key_handler(event, False)

        if not handled and skip:
            toolkit().skip_event(event)

        return handled

    def evaluate(self, function, *args, **kw_args):
        """ Call ``function`` with the given arguments in the UI's context.

        ``function`` may be None (None is returned), a callable (called
        directly), or an expression string evaluated against a copy of the
        context extended with ``ui`` and ``handler``; the expression must
        yield a callable.
        """
        if function is None:
            return None

        if not callable(function):
            namespace = dict(self.context)
            namespace["ui"] = self
            namespace["handler"] = self.handler
            function = eval(function, globals(), namespace)

        return function(*args, **kw_args)

    def eval_when(self, when, result=True):
        """ Evaluates an expression in the UI's **context** and returns the
            result.

        Parameters
        ----------
        when : str or code object
            Expression to evaluate. Names resolve against the namespace built
            by ``_get_context`` (main object's traits, context entries, ui).
        result : object
            Returned when evaluating ``when`` raises and ``raise_to_debug``
            does not re-raise.
        """
        # _get_context returns a fresh dict, so there is nothing to undo
        # afterwards.
        context = self._get_context(self.context)
        try:
            return eval(when, globals(), context)
        except Exception:
            # Same policy as _evaluate_condition: report the error without
            # swallowing KeyboardInterrupt/SystemExit.
            from traitsui.api import raise_to_debug

            raise_to_debug()

        return result

    def _get_context(self, context):
        """ Build the namespace used to evaluate ``*_when`` expressions.

        The "main" object is the only entry of a one-item context, the
        non-handler entry of a two-item context containing "handler", or
        otherwise the "object" entry. When present, its trait values form the
        base namespace, overlaid by the context entries; ``ui`` is always
        bound to this UI. A new dict is returned on every call.
        """
        size = len(context)
        if size == 1:
            main_name = next(iter(context))
        elif size == 2 and "handler" in context:
            main_name = next(key for key in context if key != "handler")
        else:
            main_name = "object"

        main_object = context.get(main_name)
        if main_object is None:
            namespace = context.copy()
        else:
            namespace = main_object.trait_get()
            namespace.update(context)

        namespace["ui"] = self
        return namespace

    def _evaluate_when(self):
        """ Trait-change handler: re-evaluate every 'visible_when',
            'enabled_when' and 'checked_when' condition, updating only the
            editors whose state changes.
        """
        return self._do_evaluate_when(at_init=False)

    def _do_evaluate_when(self, at_init=False):
        """ Set the 'visible', 'enabled', and 'checked' states for all Editors.

        Implementation of ``_evaluate_when``. It exists as a separate method
        so that traits, which inspects a notification handler's argument
        count, always sees the zero-argument ``_evaluate_when``.

        :attr:`at_init` is True on the first call during initialization,
        which forces every editor's state to be set rather than only the
        ones that change.
        """
        condition_lists = (
            ("_visible", "visible"),
            ("_enabled", "enabled"),
            ("_checked", "checked"),
        )
        for list_name, trait in condition_lists:
            self._evaluate_condition(getattr(self, list_name), trait, at_init)

    def _evaluate_condition(self, conditions, trait, at_init=False):
        """ Evaluate (condition, editor) pairs and set ``trait`` on each editor
        to the truth value of its condition.

        All conditions are evaluated first. Editors whose condition is false
        are then updated, followed by those whose condition is true. For
        'visible_when' this hides elements before showing others, so mutually
        exclusive elements are never shown together and the dialog does not
        grow unnecessarily.

        Parameters
        ----------
        conditions : list of (code, Editor) tuple
            Each pair holds an expression compiled in "eval" mode (see
            ``add_visible``/``add_enabled``/``add_checked``) and the editor
            whose state depends on it.
        trait : str
            The editor trait to set: 'visible', 'enabled' or 'checked'.
        at_init : bool
            If False, an editor is updated only when its state must change.
            If True, every editor is updated (used at initialization).
        """
        namespace = self._get_context(self.context)
        to_enable = []
        to_disable = []

        for code, editor in conditions:
            try:
                wanted = eval(code, globals(), namespace)
                current = getattr(editor, trait)
                if wanted:
                    if at_init or not current:
                        to_enable.append(editor)
                elif at_init or current:
                    to_disable.append(editor)
            except Exception:
                # Report errors raised by the '*_when' expression.
                from traitsui.api import raise_to_debug

                raise_to_debug()

        for editor in to_disable:
            setattr(editor, trait, False)
        for editor in to_enable:
            setattr(editor, trait, True)

    def _get__groups(self):
        """ Return the view's top-level Groups with Includes resolved.

        Implements the **_groups** property; the result is memoized in
        ``_groups_cache``. If any top-level element is a bare Item, the whole
        content is wrapped in a single ShadowGroup.
        """
        if self._groups_cache is not None:
            return self._groups_cache

        content = self.view.content.get_shadow(self).get_content()
        if any(isinstance(element, Item) for element in content):
            content = [
                ShadowGroup(
                    shadow=Group(*content),
                    content=content,
                    groups=1,
                )
            ]
        self._groups_cache = content
        return self._groups_cache

    # -- Property Implementations ---------------------------------------------

    def _get_key_bindings(self):
        """ Return this UI's KeyBindings, creating and caching them lazily.

        The bindings are cloned from the view's ``key_bindings`` (or a new
        KeyBindings is made) with the context values as controllers. Returns
        None while either the view or the context is missing.
        """
        if self._key_bindings is not None:
            return self._key_bindings

        view, context = self.view, self.context
        if view is None or context is None:
            return None

        controllers = list(context.values())
        template = view.key_bindings
        if template is not None:
            self._key_bindings = template.clone(controllers=controllers)
        else:
            from .key_bindings import KeyBindings

            self._key_bindings = KeyBindings(controllers=controllers)

        return self._key_bindings

    # -- Traits Event Handlers ------------------------------------------------

    def _errors_changed(self, name, old, new):
        """ Propagate a change of this UI's error count to its parent. """
        parent = self.parent
        if not parent:
            return
        parent.errors = parent.errors - old + new

    def _parent_changed(self, name, old, new):
        """ Move this UI's error count from the old parent to the new one. """
        own_errors = self.errors
        if old is not None:
            old.errors -= own_errors
        if new is not None:
            new.errors += own_errors

    def _updated_changed(self):
        """ Rebuild the UI when the View signals an update (once built). """
        if self.rebuild is None:
            return
        toolkit().rebuild_ui(self)

    def _title_changed(self):
        """ Push a new title to the toolkit window, if one exists. """
        if self.control is None:
            return
        toolkit().set_title(self)

    def _icon_changed(self):
        """ Push a new icon to the toolkit window, if one exists. """
        if self.control is None:
            return
        toolkit().set_icon(self)

    @observe("parent, view, context")
    def _pvc_changed(self, event):
        """ Link this UI to its parent once parent, view and context allow it.

        When there is a parent and this UI has key bindings, the parent's
        history is adopted if this UI has none, and this UI's KeyBindings are
        appended to the parent's KeyBindings children (if the parent has any).
        """
        parent = self.parent
        if parent is None or self.key_bindings is None:
            return

        if self.history is None:
            self.history = parent.history

        parent_bindings = parent.key_bindings
        if parent_bindings is not None:
            parent_bindings.children.append(self.key_bindings)
Пример #5
0
class BruteForceOptimizerStep(ExperimentOptimizerStep):
    """ Optimize a set of simulation parameters to model the provided
    experiment using the grid search (brute force) approach.

    If sim_group_max_size is 0, the step creates 1 simulation grid around a
    simulation built to model each target experiment. if sim_group_max_size is
    a positive integer, all simulations for a target experiments are split into
    groups of size less or equal to sim_group_max_size.

    When a simulation grid is fully run, the cost of each simulation to the
    corresponding target experiment is computed using the cost function
    attribute. The cost data from each simulation grid is stored in the
    group_cost_data dict and combined into the step's cost_data once the
    simulation names are stripped.
    """
    # General step traits -----------------------------------------------------

    #: Type of the optimizer step
    optimizer_step_type = Constant(OPTIMIZER_STEP_TYPE)

    #: List of parameter objects to scan
    parameter_list = List(ParameterScanDescription)

    #: List of parameter names to scan
    scanned_param_names = Property(List(Str), depends_on="parameter_list[]")

    # SimulationGroup related traits ------------------------------------------

    #: List of simulation groups scanning the desired parameters (one or more
    #: per target experiment). Built by initialize_sim_group from the starting
    #: point simulations and parameter_list if not provided.
    simulation_groups = List(Instance(SimulationGroup))

    #: Cost function used for each simulation group, keyed by group name;
    #: cached so that costs can be recomputed when weights change.
    group_cost_functions = Dict(Str, Callable)

    #: Maximum size for each of the simulation groups in the step. If the step
    #: needs a larger grid, it is split into SimulationGroups of size less
    #: than or equal to this (0 means a single grid per target experiment).
    sim_group_max_size = Int

    #: Index of the next simulation group to run
    _next_group_to_run = Int(0)

    #: Local storage of the job_manager to run subsequent groups
    _job_manager = Instance(JobManager)

    #: Make the run call blocking? (the ``wait`` argument given to ``run``)
    _wait_on_run = Bool

    # Run related traits ------------------------------------------------------

    #: Total number of simulations involved in the optimization step
    size = Property(Int, depends_on="simulation_groups[]")

    #: Number of simulations already run
    size_run = Property(Int, depends_on="simulation_groups.size_run")

    #: Percentage of the optimizer that has already run
    percent_run = Property(Str, depends_on="size_run")

    # Output related traits ---------------------------------------------------

    #: Aggregation method to combine costs for all components & all experiments
    cost_agg_func = Enum("sum", "mean")

    #: Dict mapping each target experiment name to its cost data: the data of
    #: all simulation groups targeting that experiment, concatenated.
    _group_cost_data = Dict

    #: Dict mapping each component to a list of the best simulations
    optimal_simulation_for_comp = Dict
    # Run related methods -----------------------------------------------------

    def run(self, job_manager, wait=False):
        """ Run optimization step by running all simulation groups it contains.

        The simulation groups are built first if none were provided. Only the
        first group is launched here; the job manager, the index of the next
        group and the ``wait`` flag are recorded so that the remaining groups
        can be launched later.

        Parameters
        ----------
        job_manager : JobManager
            Job manager used to run the first group and, later, the others.
        wait : bool
            Whether running the group should block until it completes.

        Returns
        -------
        The runner returned by the first simulation group's ``run``.
        """
        # Initialize run parameters
        super(BruteForceOptimizerStep, self).run(job_manager, wait=wait)
        if not self.simulation_groups:
            self.initialize_sim_group()

        first_group = self.simulation_groups[0]

        # Record the follow-up state *before* launching the first group. A
        # blocking run (wait=True) can complete inside first_group.run(), so
        # anything reacting to that completion (presumably the code that
        # launches the next group) must already see this state.
        self._job_manager = job_manager
        self._next_group_to_run = 1
        self._wait_on_run = wait

        return first_group.run(job_manager, wait=wait)

    def wait(self):
        """ Block until every currently known simulation group has finished.
        """
        for sim_group in self.simulation_groups:
            logger.debug("Waiting for %s to finish...", sim_group.name)
            sim_group.wait()

    def initialize_sim_group(self):
        """ Build the simulation groups, one grid per target experiment.

        Each grid scans ``parameter_list`` around the matching starting-point
        simulation. Depending on ``sim_group_max_size``, a grid may be split
        into several groups by ``param_scans_to_sim_group``.
        """
        experiment_pairs = zip(self.target_experiments,
                               self.starting_point_simulations)
        for experiment, center_sim in experiment_pairs:
            grid_name = "Grid {}_{}".format(experiment.name, self.name)
            new_groups = param_scans_to_sim_group(
                grid_name,
                self.parameter_list,
                center_sim,
                max_size=self.sim_group_max_size,
            )
            self.simulation_groups.extend(new_groups)

    # Cost related methods ----------------------------------------------------

    def recompute_costs_for_weights(self, new_weights):
        """ Assume new weights for all cost functions.

        The weights are always stored in ``cost_func_kw`` so that any cost
        function created afterwards uses them. If the step has already run,
        the cost data of every group is also recomputed (from the metrics
        cached in the group cost functions) and re-aggregated.

        Parameters
        ----------
        new_weights : object
            Weights assigned to the cost functions' ``weights``.
        """
        # Keep cost_func_kw in sync in every case; previously it was left
        # stale once the step had run, so later-created cost functions (e.g.
        # in compute_costs) would silently use the old weights.
        self.cost_func_kw["weights"] = new_weights
        if not self.has_run:
            return

        # Otherwise, recompute all costs data (using cached metrics stored in
        # the cost functions):
        self.invalidate_group_cost_data()
        for group in self.simulation_groups:
            # Rebuild the simulations so that we can recover parameter values
            # for the cost data dataframe:
            if not group.simulations:
                group.initialize_simulations(use_output_cache=True)

            cost_func = self.group_cost_functions[group.name]
            cost_func.weights = new_weights
            cost_data = cost_func.compute_costs()
            # Don't aggregate yet, to avoid triggering listeners until all
            # cost_data recomputed:
            self.update_cost_data_dict(group, cost_data, skip_aggregate=True)

        # Now we are ready to compute the step's cost_data:
        self.aggregate_cost_data()

    def compute_costs(self, sim_group, cost_function=None):
        """ Compute the costs of one of the step's SimulationGroups.

        The cost function used is cached per group name so that costs can be
        recomputed if the weights change.

        Parameters
        ----------
        sim_group : SimulationGroup
            Group for which to compute costs.

        cost_function : Callable [OPTIONAL]
            Cost function to use. When None, one is built from
            ``ALL_COST_FUNCTIONS[self.cost_function_type]`` and
            ``cost_func_kw``.
        """
        if cost_function is None:
            cost_class = ALL_COST_FUNCTIONS[self.cost_function_type]
            cost_function = cost_class(**self.cost_func_kw)

        experiment = sim_group.center_point_simulation.source_experiment
        group_costs = cost_function(sim_group.simulations,
                                    target_exps=experiment)
        self.group_cost_functions[sim_group.name] = cost_function
        self.update_cost_data_dict(sim_group, group_costs)

    def update_cost_data_dict(self, group, cost_data, skip_aggregate=False):
        """ Store the cost data of one simulation group.

        The data is copied and extended with an aggregate column
        (``ALL_COST_COL_NAME``, the row-wise ``cost_agg_func`` of the existing
        columns) and with the scanned parameter values. It is then
        concatenated with any data already collected for the same target
        experiment. When the step has run and ``skip_aggregate`` is False,
        the step's ``cost_data`` is re-aggregated.

        Parameters
        ----------
        group : SimulationGroup
            Group the cost data belongs to.
        cost_data : DataFrame or None
            Cost data from the cost function; nothing happens when None.
        skip_aggregate : bool
            If True, don't re-aggregate the step's cost data yet.
        """
        if cost_data is None:
            return

        # Work on a copy: the cost function object keeps a hold on its data.
        group_costs = cost_data.copy()
        group_sims = group.simulations

        aggregate = getattr(group_costs, self.cost_agg_func)
        group_costs[ALL_COST_COL_NAME] = aggregate(axis=1)

        self.append_param_values(group_costs, group_sims)

        exp_name = group.center_point_simulation.source_experiment.name
        previous = self._group_cost_data.get(exp_name)
        if previous is None:
            self._group_cost_data[exp_name] = group_costs
        else:
            self._group_cost_data[exp_name] = pd.concat([previous,
                                                         group_costs])

        if self.has_run and not skip_aggregate:
            self.aggregate_cost_data()

    def invalidate_group_cost_data(self):
        """ Discard every per-experiment cost table collected so far.
        """
        self._group_cost_data = dict()

    def aggregate_cost_data(self):
        """ Combine the per-experiment cost tables into the step's cost_data.

        The simulation-name column is dropped from every experiment's table,
        because a given parameter setup has one simulation per target
        experiment, so names can't be combined. The remaining tables are then
        added together. When ``cost_agg_func`` is ``"mean"``, the total is
        divided by the number of target experiments.
        """
        # Same accumulation as the builtin ``sum``: start at 0, add in order.
        total = 0
        for experiment_costs in self._group_cost_data.values():
            total = total + experiment_costs.drop(SIM_COL_NAME, axis=1)

        if self.cost_agg_func == "mean":
            total /= len(self.target_experiments)

        self.cost_data = total

    def append_param_values(self, costs_df, simulations):
        """ Evaluate parameters for provided sims and reset as cost DF index.

        For each scanned parameter name, ``sim.<param_name>`` is evaluated on
        every simulation and stored as a column of ``costs_df``, which is
        modified in place. Values are converted so they can serve as index
        labels: UnitScalars and squeezable values become floats, and repeating
        arrays are replaced by their first element. The current index is then
        moved back into a column, and the scanned parameter columns become the
        new index.

        Parameters
        ----------
        costs_df : DataFrame
            Cost data with one row per simulation. Presumably in the same
            order as ``simulations`` -- TODO confirm. Modified in place.
        simulations : list
            Simulations whose scanned parameter values are collected.
        """
        for param_name in self.scanned_param_names:
            # Parameter names can be more than plain attribute names (the
            # repeating-array case below comes from slices of arrays), hence
            # eval. NOTE(review): only acceptable because the names come from
            # the step's own parameters; never pass untrusted strings here.
            expr = "sim.{}".format(param_name)
            costs_df[param_name] = [eval(expr, {"sim": sim})
                                    for sim in simulations]
            if len(costs_df) == 0:
                # No rows: nothing to inspect or convert.
                continue

            # Positional access: Series[0] is a *label* lookup on integer
            # indexes, and its positional fallback on other indexes is
            # deprecated in pandas.
            first_val = costs_df[param_name].iloc[0]
            if isinstance(first_val, UnitScalar) or is_squeezable(first_val):
                # FIXME: WHEN DOES THE SQUEEZABLE (NON-UNITSCALAR) CASE HAPPEN?
                costs_df[param_name] = costs_df[param_name].apply(float)
            elif is_repeating_array(first_val):
                # This can happen when a parameter is a slice of an array:
                # replace with the first value if all the same because we can't
                # index with an array (unhashable).
                costs_df[param_name] = costs_df[param_name].apply(
                    lambda x: x[0]
                )

        costs_df.reset_index(inplace=True)
        costs_df.set_index(self.scanned_param_names, inplace=True)

    # Optimal simulation methods ----------------------------------------------

    def update_optimal_simulation_for_comp(self):
        """ Extract the best simulation for each product component.

        For every target component, and for every target experiment's cost
        table, the row with the smallest cost for that component is located,
        and the simulation named in that row is retrieved. The result,
        stored in ``optimal_simulation_for_comp``, maps each component to the
        list of best simulations found. Failures to locate or resolve the
        best simulation are logged rather than raised.
        """
        best_simulations = defaultdict(list)
        for comp in self.target_components:
            for group_cost_data in self._group_cost_data.values():
                data = group_cost_data[comp]
                try:
                    # idxmin returns the index *label* of the minimum, which
                    # is what .loc needs (the index holds the scanned
                    # parameter values). Series.argmin returns a *position*
                    # in modern pandas and must not be fed to .loc.
                    idx = data.idxmin()
                    sim_name = group_cost_data.loc[idx, SIM_COL_NAME]
                    sim = self._get_sim_from_name(sim_name)
                    best_simulations[comp].append(sim)
                except Exception as e:
                    msg = "Failing to find the simulation with minimal cost " \
                          "for component {}. Data was {}. (Exception was {})"
                    logger.error(msg.format(comp, data, e))

        self.optimal_simulation_for_comp = best_simulations

    def get_optimal_sims(self, exp_name, num_sims):
        """ Return up to num_sims lowest-cost simulations targeting exp_name.

        Rows of the step's ``cost_data`` are ranked by aggregate cost. The
        index labels of the first ``num_sims`` rows are used to look up
        simulation names in the cost table collected for ``exp_name``. Fewer
        simulations are returned when fewer rows exist, because slicing
        clamps. An empty ``cost_data`` yields an empty list.
        """
        if not len(self.cost_data):
            return []

        ranked = self.cost_data.sort_values(by=ALL_COST_COL_NAME)
        best_labels = ranked.index[:num_sims]

        # Relies on cost_data and every entry of _group_cost_data sharing the
        # same index (the scanned parameter values).
        experiment_costs = self._group_cost_data[exp_name]
        best_names = experiment_costs.loc[best_labels, SIM_COL_NAME].tolist()
        return list(map(self._get_sim_from_name, best_names))

    # Private interface -------------------------------------------------------

    def _get_sim_from_name(self, sim_name):
        """ Find a simulation ran in the step in the simulation sim groups.

        Raises
        ------
        ValueError
            If the simulation isn't found.
        """
        pattern = "Sim (\d+)_(.+)"
        match = re.match(pattern, sim_name)
        target_sim_num, target_group_name = match.groups()
        group = self._get_group_from_name(target_group_name)
        try:
            sim = group.get_simulation(int(target_sim_num))
            if sim.name != sim_name:
                msg = "Logical error: the simulation's name isn't what was " \
                      "expected!"
                logger.exception(msg)
                raise ValueError(msg)

            return sim
        except (IndexError, AssertionError) as e:
            msg = "Simulation with name {} not found in step's simulation " \
                  "groups. Error was {}."
            msg = msg.format(sim_name, e)
            logger.error(msg)
            raise ValueError(msg)

    def _get_group_from_name(self, group_name):
        """ Return the simulation group with provided name.
        """
        for group in self.simulation_groups:
            if group.name.startswith(group_name):
                return group

        msg = "SimulationGroup with name {} not found in step's groups. " \
              "Known names are {}"
        known_group_names = [group.name for group in self.simulation_groups]
        msg = msg.format(group_name, known_group_names)
        logger.error(msg)
        raise ValueError(msg)

    def _get_step_has_run(self):
        if not self.simulation_groups:
            return False
        return all([group.has_run for group in self.simulation_groups])

    # Traits listeners --------------------------------------------------------

    @on_trait_change("simulation_groups:has_run")
    def optimize_costs(self, sim_group, attr_name, group_has_run):
        """ React to a change of a simulation group's ``has_run`` flag.

        The step's own ``has_run`` flag is always refreshed. When the group
        has just finished, its costs are computed. Then either the optimal
        simulations are extracted (if the whole step is done) or the next
        group is submitted. Finally, the group's simulations are released and
        ``data_updated`` is set.
        """
        self.has_run = self._get_step_has_run()
        if not group_has_run:
            return

        msg = "Group {} has finished running: updating costs."
        logger.info(msg.format(sim_group.name))

        self.compute_costs(sim_group)
        if not self.has_run:
            self._run_next_sim_group()
        else:
            self.update_optimal_simulation_for_comp()

        # The simulations can be rebuilt from the simulation diffs, so drop
        # them to save memory.
        sim_group.release_simulation_list()
        self.data_updated = True

    def _run_next_sim_group(self):
        """ Submit the next simulation group in line, then advance the index.
        """
        upcoming = self.simulation_groups[self._next_group_to_run]
        logger.debug("Now submitting {} to run...".format(upcoming.name))
        upcoming.run(self._job_manager, wait=self._wait_on_run)
        # NOTE(review): the index only advances after run() returns. If a
        # blocking run could re-enter this method through the has_run
        # listener, the same index would be read again -- confirm it can't.
        self._next_group_to_run += 1

    # Traits property getters -------------------------------------------------

    def _get_size(self):
        return sum([group.size for group in self.simulation_groups])

    def _get_size_run(self):
        return sum([group.size_run for group in self.simulation_groups])

    def _get_percent_run(self):
        if self.size:
            percent_run = self.size_run / self.size * 100.
        else:
            percent_run = np.nan

        return "{:.2f} %".format(percent_run)

    def _get_scanned_param_names(self):
        step_params = []
        for param in self.parameter_list:
            p_name = param.name
            parallel_params = hasattr(param, "parallel_parameters") and \
                len(param.parallel_parameters) > 0
            if parallel_params:
                step_params.extend([p.name for p in param.parallel_parameters])

            step_params.append(p_name)

        return step_params

    # Traits initialization methods -------------------------------------------

    def _cost_data_default(self):
        """ Empty cost table: one column per target component plus the
        aggregate cost column, and no rows.
        """
        columns = self.target_components + [ALL_COST_COL_NAME]
        empty_columns = dict((col, []) for col in columns)
        return pd.DataFrame(empty_columns, index=[])

    def _sim_group_max_size_default(self):
        """ Default group size: the optimizer step chunk size preference. """
        optimizer_prefs = get_preferences().optimizer_preferences
        return optimizer_prefs.optimizer_step_chunk_size
Пример #6
0
class TimeSamples( SamplesGenerator ):
    """
    Container for time data in `*.h5` format.

    This class loads measured data from h5 files and
    provides information about this data.
    It also serves as an interface where the data can be accessed
    (e.g. for use in a block chain) via the :meth:`result` generator.
    """

    #: Full name of the .h5 file with data.
    name = File(filter=['*.h5'],
        desc="name of data file")

    #: Basename of the .h5 file with data, is set automatically.
    basename = Property( depends_on = 'name',
        desc="basename of data file")

    #: Calibration data, instance of :class:`~acoular.calib.Calib` class, optional .
    calib = Trait( Calib,
        desc="Calibration data")

    #: Number of channels, is set automatically / read from file.
    numchannels = CLong(0,
        desc="number of input channels")

    #: Number of time data samples, is set automatically / read from file.
    numsamples = CLong(0,
        desc="number of samples")

    #: The time data as array of floats with dimension (numsamples, numchannels).
    data = Any( transient = True,
        desc="the actual time data array")

    #: HDF5 file object
    h5f = Instance(H5FileBase, transient = True)

    #: Provides metadata stored in HDF5 file object
    metadata = Dict(
        desc="metadata contained in .h5 file")

    # Checksum over first data entries of all channels
    _datachecksum = Property()

    # internal identifier
    digest = Property( depends_on = ['basename', 'calib.digest', '_datachecksum'])

    def _get__datachecksum( self ):
        # Sum of the first sample (row 0) over all channels.
        return self.data[0,:].sum()

    @cached_property
    def _get_digest( self ):
        return digest(self)

    @cached_property
    def _get_basename( self ):
        # File name without directory and without extension.
        return path.splitext(path.basename(self.name))[0]

    @on_trait_change('basename')
    def load_data( self ):
        """
        Open the .h5 file and set attributes.

        Runs whenever :attr:`basename` changes. A previously opened file is
        closed first; an IOError while closing it is ignored. The time data
        and the metadata are then read from the new file.

        Raises
        ------
        IOError
            If :attr:`name` is not an existing file. In that case
            :attr:`numsamples`, :attr:`numchannels` and ``sample_freq`` are
            reset to 0 before raising.
        """
        if not path.isfile(self.name):
            # no file there
            self.numsamples = 0
            self.numchannels = 0
            self.sample_freq = 0
            raise IOError("No such file: %s" % self.name)
        if self.h5f is not None:
            try:
                self.h5f.close()
            except IOError:
                # best effort: failing to close the previous file must not
                # prevent loading the new one
                pass
        h5file_class = _get_h5file_class()
        self.h5f = h5file_class(self.name)
        self.load_timedata()
        self.load_metadata()

    def load_timedata( self ):
        """ loads timedata from .h5 file. Only for internal use.

        Reads the ``time_data`` node and its ``sample_freq`` attribute, and
        sets :attr:`numsamples` and :attr:`numchannels` from the data shape.
        """
        self.data = self.h5f.get_data_by_reference('time_data')
        self.sample_freq = self.h5f.get_node_attribute(self.data,'sample_freq')
        (self.numsamples, self.numchannels) = self.data.shape

    def load_metadata( self ):
        """ loads metadata from .h5 file. Only for internal use.

        Every child node of ``/metadata`` (if that group exists) is stored in
        :attr:`metadata` under its node name.
        """
        self.metadata = {}
        if '/metadata' in self.h5f:
            for nodename, nodedata in self.h5f.get_child_nodes('/metadata'):
                self.metadata[nodename] = nodedata

    def result(self, num=128):
        """
        Python generator that yields the output block-wise.

        Parameters
        ----------
        num : integer, defaults to 128
            This parameter defines the size of the blocks to be yielded
            (i.e. the number of samples per block) .

        Yields
        ------
        Samples in blocks of shape (num, numchannels).
            The last block may be shorter than num. If :attr:`calib` is set,
            each block is multiplied by the calibration factors.

        Raises
        ------
        IOError
            If there are no samples.
        ValueError
            If the calibration data has a different number of channels.
        """
        if self.numsamples == 0:
            raise IOError("no samples available")
        self._datachecksum # trigger checksum calculation
        i = 0
        if self.calib:
            if self.calib.num_mics == self.numchannels:
                # one factor per channel, broadcast over the sample axis
                cal_factor = self.calib.data[newaxis]
            else:
                raise ValueError("calibration data not compatible: %i, %i" %
                            (self.calib.num_mics, self.numchannels))
            while i < self.numsamples:
                yield self.data[i:i+num]*cal_factor
                i += num
        else:
            while i < self.numsamples:
                yield self.data[i:i+num]
                i += num
Пример #7
0
class FlowPeaksOp(HasStrictTraits):
    """
    This module uses the **flowPeaks** algorithm to assign events to clusters in
    an unsupervised manner.
    
    Call :meth:`estimate` to compute the clusters.
      
    Calling :meth:`apply` creates a new categorical metadata variable
    named ``name``, with possible values ``{name}_1`` ... ``{name}_n``, where
    ``n`` is the number of clusters estimated.
    
    The same model may not be appropriate for different subsets of the data set.
    If this is the case, you can use the :attr:`by` attribute to specify 
    metadata by which to aggregate the data before estimating (and applying) 
    a model.  The number of clusters is a model parameter and it may vary in 
    each subset. 

    Attributes
    ----------
    name : Str
        The operation name; determines the name of the new metadata column
        
    channels : List(Str)
        The channels to apply the clustering algorithm to.

    scale : Dict(Str : Enum("linear", "logicle", "log"))
        Re-scale the data in the specified channels before fitting.  If a 
        channel is in :attr:`channels` but not in :attr:`scale`, the current 
        package-wide default (set with :func:`set_default_scale`) is used.
    
    by : List(Str)
        A list of metadata attributes to aggregate the data before estimating
        the model.  For example, if the experiment has two pieces of metadata,
        ``Time`` and ``Dox``, setting ``by = ["Time", "Dox"]`` will fit the model 
        separately to each subset of the data with a unique combination of
        ``Time`` and ``Dox``.
        
    h : Float (default = 1.5)
        A scalar value by which to scale the covariance matrices of the 
        underlying density function.  (See ``Notes``, below, for more details.)
        
    h0 : Float (default = 1.0)
        A scalar value by which to smooth the covariance matrices of the
        underlying density function.  (See ``Notes``, below, for more details.)
        
    tol : Float (default = 0.5)
        How readily should clusters be merged?  Must be between 0 and 1.
        See ``Notes``, below, for more details.
        
    merge_dist : Float (default = 5)
        How far apart can clusters be before they are merged?  This is
        a unit-free scalar, and is approximately the maximum number of
        k-means clusters between peaks. 
        
    find_outliers : Bool (default = False)
        Should the algorithm use an extra step to identify outliers?
        
        .. note::
            I have disabled this code until I can try to make it faster.
        
    Notes
    -----
    
    This algorithm uses kmeans to find a large number of clusters, then 
    hierarchically merges those clusters.  Thus, the user does not need to
    specify the number of clusters in advance, and it can find non-convex
    clusters.  It also operates in an arbitrary number of dimensions.
    
    The merging happens in two steps.  First, the cluster centroids are used
    to estimate an underlying density function.  Then, the local maxima of
    the density function are found using a numerical optimization starting from
    each centroid, and k-means clusters that converge to the same local maximum
    are merged.  Finally, these clusters-of-clusters are merged if their local 
    maxima are (a) close enough, and (b) the density function between them is 
    smooth enough.  Thus, the final assignment of each event depends on the 
    k-means cluster it ends up in, and which cluster-of-clusters that k-means 
    centroid is assigned to.
    
    There are a lot of parameters that affect this process.  The k-means
    clustering is pretty robust (though somewhat sensitive to the number of 
    clusters, which is currently not exposed in the API.) The most important
    are exposed as attributes of the :class:`FlowPeaksOp` class.  These include:
    
    - :attr:`h`, :attr:`h0`: sometimes the density function is too "rough" to
        find good local maxima.  These parameters smooth it out by widening the
        covariance matrices: :attr:`h` scales each k-means cluster's own
        covariance and :attr:`h0` scales a covariance derived from the range
        of the data, so increasing either one makes the density smoother.
              
    - :attr:`tol`: How smooth does the density function have to be between two 
        density maxima to merge them?  Must be between 0 and 1.
           
    - :attr:`merge_dist`: How close must two maxima be to merge them?  This 
        value is a unit-free scalar, and is approximately the number of
        k-means clusters between the two maxima.
        
    For details and a theoretical justification, see [1]_
    
    References
    ----------
    
    .. [1] Ge, Yongchao and Sealfon, Stuart C.  "flowPeaks: a fast unsupervised
       clustering for flow cytometry data via K-means and density peak finding" 
       Bioinformatics (2012) 28 (15): 2052-2058.         
  
    Examples
    --------
    
    .. plot::
        :context: close-figs
        
        Make a little data set.
    
        >>> import cytoflow as flow
        >>> import_op = flow.ImportOp()
        >>> import_op.tubes = [flow.Tube(file = "Plate01/RFP_Well_A3.fcs",
        ...                              conditions = {'Dox' : 10.0}),
        ...                    flow.Tube(file = "Plate01/CFP_Well_A4.fcs",
        ...                              conditions = {'Dox' : 1.0})]
        >>> import_op.conditions = {'Dox' : 'float'}
        >>> ex = import_op.apply()
    
    Create and parameterize the operation.
    
    .. plot::
        :context: close-figs
        
        >>> fp_op = flow.FlowPeaksOp(name = 'Flow',
        ...                          channels = ['V2-A', 'Y2-A'],
        ...                          scale = {'V2-A' : 'log',
        ...                                   'Y2-A' : 'log'},
        ...                          h0 = 3)
        
    Estimate the clusters
    
    .. plot::
        :context: close-figs
        
        >>> fp_op.estimate(ex)
        
    Plot a diagnostic view of the underlying density
    
    .. plot::
        :context: close-figs
        
        >>> fp_op.default_view(density = True).plot(ex)

    Apply the gate
    
    .. plot::
        :context: close-figs
        
        >>> ex2 = fp_op.apply(ex)

    Plot a diagnostic view with the event assignments
    
    .. plot::
        :context: close-figs
        
        >>> fp_op.default_view().plot(ex2)
        

    """

    id = Constant('edu.mit.synbio.cytoflow.operations.flowpeaks')
    friendly_id = Constant("FlowPeaks Clustering")

    name = CStr()
    channels = List(Str)
    scale = Dict(Str, util.ScaleEnum)
    by = List(Str)
    #     find_outliers = Bool(False)

    # parameters that control estimation, with sensible defaults
    h = util.PositiveFloat(1.5, allow_zero=False)
    h0 = util.PositiveFloat(1, allow_zero=False)
    tol = util.PositiveFloat(0.5, allow_zero=False)
    merge_dist = util.PositiveFloat(5, allow_zero=False)

    # parameters that control outlier selection, with sensible defaults

    _kmeans = Dict(Any,
                   Instance(sklearn.cluster.MiniBatchKMeans),
                   transient=True)
    _normals = Dict(Any, List(Function), transient=True)
    _density = Dict(Any, Function, transient=True)
    _peaks = Dict(Any, List(Array), transient=True)
    _cluster_peak = Dict(Any, List,
                         transient=True)  # kmeans cluster idx --> peak idx
    _cluster_group = Dict(Any, List,
                          transient=True)  # kmeans cluster idx --> group idx
    _scale = Dict(Str, Instance(util.IScale), transient=True)

    def estimate(self, experiment, subset=None):
        """
        Estimate the k-means clusters, then hierarchically merge them.

        For each group of events defined by :attr:`by` (all events if
        :attr:`by` is empty), this:

        1. scales the data in :attr:`channels` and drops events that are NaN
           after scaling. The scales are estimated once, on the whole
           (possibly subsetted) experiment, and saved for :meth:`apply`;
        2. fits a k-means model whose number of clusters is the median, over
           channels, of ``util.num_hist_bins``;
        3. builds a Gaussian mixture density from the k-means clusters. Each
           covariance blends the cluster's own covariance (scaled by
           :attr:`h`) with a diagonal covariance derived from each channel's
           data range (scaled by :attr:`h0`);
        4. finds the density peak reached from each centroid, and groups
           clusters whose peaks coincide;
        5. repeatedly merges the closest pair of peak groups that meets the
           :attr:`tol` and :attr:`merge_dist` criteria.

        Parameters
        ----------
        experiment : Experiment
            The :class:`.Experiment` to use to estimate the k-means clusters

        subset : str (default = None)
            A Python expression that specifies a subset of the data in
            ``experiment`` to use to parameterize the operation.

        Raises
        ------
        CytoflowOpError
            If the experiment, channels, scales, ``by`` conditions or subset
            are invalid, if a group has no data, or if the ``tol``
            optimization between two peaks fails.
        """

        if experiment is None:
            raise util.CytoflowOpError('experiment', "No experiment specified")

        if len(self.channels) == 0:
            raise util.CytoflowOpError('channels',
                                       "Must set at least one channel")

        for c in self.channels:
            if c not in experiment.data:
                raise util.CytoflowOpError(
                    'channels',
                    "Channel {0} not found in the experiment".format(c))

        for c in self.scale:
            if c not in self.channels:
                raise util.CytoflowOpError(
                    'scale', "Scale set for channel {0}, but it isn't "
                    "in the experiment".format(c))

        for b in self.by:
            if b not in experiment.conditions:
                raise util.CytoflowOpError(
                    'by', "Aggregation metadata {} not found, "
                    "must be one of {}".format(b, experiment.conditions))

        if subset:
            try:
                experiment = experiment.query(subset)
            except Exception:
                # any failure of the query means the expression is unusable;
                # don't swallow KeyboardInterrupt/SystemExit though
                raise util.CytoflowOpError(
                    'subset', "Subset string '{0}' isn't valid".format(subset))

            if len(experiment) == 0:
                raise util.CytoflowOpError(
                    'subset',
                    "Subset string '{0}' returned no events".format(subset))

        if self.by:
            groupby = experiment.data.groupby(self.by)
        else:
            # use a lambda expression to return a group that contains
            # all the events
            groupby = experiment.data.groupby(lambda _: True)

        # get the scale. estimate the scale params for the ENTIRE data set,
        # not subsets we get from groupby().  And we need to save it so that
        # the data is transformed the same way when we apply()
        for c in self.channels:
            if c in self.scale:
                self._scale[c] = util.scale_factory(self.scale[c],
                                                    experiment,
                                                    channel=c)
            else:
                self._scale[c] = util.scale_factory(util.get_default_scale(),
                                                    experiment,
                                                    channel=c)

        d = len(self.channels)

        for data_group, data_subset in groupby:
            if len(data_subset) == 0:
                raise util.CytoflowOpError(
                    'by', "Group {} had no data".format(data_group))

            # explicit copy: scaling must not write into a view of the
            # experiment's data
            x = data_subset.loc[:, list(self.channels)].copy()
            for c in self.channels:
                x[c] = self._scale[c](x[c])

            # drop data that isn't in the scale range
            for c in self.channels:
                x = x[~(np.isnan(x[c]))]
            x = x.values

            #### choose the number of clusters and fit the kmeans
            num_clusters = [util.num_hist_bins(x[:, c]) for c in range(d)]
            num_clusters = int(np.ceil(np.median(num_clusters)))

            self._kmeans[data_group] = kmeans = \
                sklearn.cluster.MiniBatchKMeans(n_clusters = num_clusters,
                                                random_state = 0)

            kmeans.fit(x)
            x_labels = kmeans.predict(x)

            #### use the kmeans centroids to parameterize a finite gaussian
            #### mixture model which estimates the density function

            # Background covariance: diagonal, built from the range of each
            # channel's (scaled) data -- column j, not row j.
            s0 = np.zeros([d, d])
            for j in range(d):
                r = x[:, j].max() - x[:, j].min()
                s0[j, j] = (r / (num_clusters**(1. / d)))**0.5

            means = []
            weights = []
            normals = []

            for k in range(num_clusters):
                xk = x[x_labels == k]
                num_k = np.sum(x_labels == k)
                weight_k = num_k / len(x_labels)
                mu = xk.mean(axis=0)
                means.append(mu)
                s = np.cov(xk, rowvar=False)

                # Blend the cluster's own covariance with the background one;
                # the more events in the cluster, the more weight on its own.
                el = num_k / (num_clusters + num_k)
                s_smooth = el * self.h * s + (1.0 - el) * self.h0 * s0

                n = scipy.stats.multivariate_normal(mean=mu, cov=s_smooth)
                weights.append(weight_k)
                # bind n as a default: a plain closure would see only the
                # last n of the loop
                normals.append(lambda pts, n=n: n.pdf(pts))

            self._normals[data_group] = normals
            self._density[data_group] = density = \
                lambda pts, weights=weights, normals=normals: np.sum(
                    [w * n(pts) for w, n in zip(weights, normals)], axis=0)

            ### use optimization on the finite gmm to find the local peak for
            ### each kmeans cluster
            peaks = []
            peak_clusters = []  # peak idx --> list of clusters

            def neg_density(pt):
                return -1.0 * density(pt)

            for k in range(num_clusters):
                # Conjugate-gradient minimization of -density starting at the
                # centroid. (The gradient-ascent peak search described in the
                # flowPeaks paper gives similar results but is much slower.)
                res = scipy.optimize.minimize(neg_density,
                                              means[k],
                                              method="CG",
                                              options={'gtol': 1e-3})

                if not res.success:
                    warn(
                        "Peak finding failed for cluster {}: {}".format(
                            k, res.message), util.CytoflowWarning)

                merged = False
                for pi, p in enumerate(peaks):
                    # TODO - this probably only works for scaled measurements
                    if np.linalg.norm(p - res.x) < (1e-2):
                        peak_clusters[pi].append(k)
                        merged = True
                        break

                if not merged:
                    peak_clusters.append([k])
                    peaks.append(res.x)

            self._peaks[data_group] = peaks

            ### merge peaks that are sufficiently close

            groups = [[pi] for pi in range(len(peaks))]

            def max_tol(pa, pb):
                # Largest relative deviation, along the segment from pa to pb,
                # between the density and the straight line joining the
                # density values at pa and pb.
                f = lambda a: density(a[np.newaxis, :])

                def tol(t):
                    zt = pa + t * (pb - pa)
                    fhat_zt = f(pa) + t * (f(pb) - f(pa))
                    return -1.0 * abs((f(zt) - fhat_zt) / fhat_zt)

                res = scipy.optimize.minimize_scalar(tol,
                                                     bounds=[0, 1],
                                                     method='Bounded')

                if res.status != 0:
                    raise util.CytoflowOpError(
                        None,
                        "tol optimization failed for {}, {}".format(pa, pb))
                return -1.0 * res.fun

            def nearest_neighbor_dist(k):
                min_dist = np.inf

                for i in range(num_clusters):
                    if i == k:
                        continue
                    dist = np.linalg.norm(means[k] - means[i])
                    if dist < min_dist:
                        min_dist = dist

                return min_dist

            # distance from each k-means centroid to its nearest other centroid
            sk = [nearest_neighbor_dist(k) for k in range(num_clusters)]

            def s(pt):
                k = kmeans.predict(pt[np.newaxis, :])[0]
                return sk[k]

            def can_merge(g, h):
                for pg in g:
                    for ph in h:
                        vg = peaks[pg]
                        vh = peaks[ph]
                        dist_gh = np.linalg.norm(vg - vh)

                        if max_tol(vg, vh) < self.tol and dist_gh / (
                                s(vg) + s(vh)) <= self.merge_dist:
                            return True

                return False

            while len(groups) > 1:
                # find closest mergable groups
                min_dist = np.inf
                min_gi = min_hi = None
                for gi in range(len(groups)):
                    g = groups[gi]

                    for hi in range(gi + 1, len(groups)):
                        h = groups[hi]

                        if can_merge(g, h):
                            dist_gh = np.inf
                            for pg in g:
                                vg = peaks[pg]
                                for ph in h:
                                    vh = peaks[ph]
                                    dist_gh = min(dist_gh,
                                                  np.linalg.norm(vg - vh))

                            if dist_gh < min_dist:
                                min_gi = gi
                                min_hi = hi
                                min_dist = dist_gh

                if min_dist == np.inf:
                    break

                # merge the groups
                groups[min_gi].extend(groups[min_hi])
                del groups[min_hi]

            cluster_group = [0] * num_clusters
            cluster_peaks = [0] * num_clusters

            for gi, g in enumerate(groups):
                for p in g:
                    for cluster in peak_clusters[p]:
                        cluster_group[cluster] = gi
                        cluster_peaks[cluster] = p

            self._cluster_peak[data_group] = cluster_peaks
            self._cluster_group[data_group] = cluster_group

    def apply(self, experiment):
        """
        Assign events to a cluster.
        
        Assigns each event to one of the k-means centroids from :meth:`estimate`,
        then groups together events in the same cluster hierarchy.  Events
        with a missing (NaN) value in any of :attr:`channels` are not
        clustered.
        
        Parameters
        ----------
        experiment : Experiment
            the :class:`.Experiment` to apply the gate to.
            
        Returns
        -------
        Experiment
            A new :class:`.Experiment` with the gate applied to it.  It has a
            new ``category`` condition named :attr:`name`.  Each event's value
            is ``{name}_1``, ``{name}_2``, ... for the cluster group it was
            assigned to, or ``{name}_None`` if it was not assigned.  No
            statistics are added.

        Raises
        ------
        CytoflowOpError
            If the operation is misconfigured, :meth:`estimate` has not been
            called, or a group in ``experiment`` was not seen by
            :meth:`estimate`.
        """

        if experiment is None:
            raise util.CytoflowOpError('experiment', "No experiment specified")

        # make sure name got set!
        if not self.name:
            raise util.CytoflowOpError(
                'name', "You have to set the gate's name "
                "before applying it!")

        if self.name in experiment.data.columns:
            raise util.CytoflowOpError(
                'name',
                "Experiment already has a column named {0}".format(self.name))

        if len(self.channels) == 0:
            raise util.CytoflowOpError('channels',
                                       "Must set at least one channel")

        if not self._peaks:
            raise util.CytoflowOpError(
                None, "No model found.  Did you forget to "
                "call estimate()?")

        for c in self.channels:
            if c not in experiment.data:
                raise util.CytoflowOpError(
                    'channels',
                    "Channel {0} not found in the experiment".format(c))

        for c in self.scale:
            if c not in self.channels:
                raise util.CytoflowOpError(
                    'scale', "Scale set for channel {0}, but it isn't "
                    "in the experiment".format(c))

        for b in self.by:
            if b not in experiment.conditions:
                raise util.CytoflowOpError(
                    'by', "Aggregation metadata {} not found, "
                    "must be one of {}".format(b, experiment.conditions))

        if self.by:
            groupby = experiment.data.groupby(self.by)
        else:
            # use a lambda expression to return a group that contains
            # all the events
            groupby = experiment.data.groupby(lambda _: True)

        none_label = "{}_None".format(self.name)

        # Index the assignments exactly like the data, so each group's labels
        # can be written back by index label rather than by position.
        event_assignments = pd.Series(none_label,
                                      index=experiment.data.index,
                                      dtype="object")

        # make the statistics
        #         clusters = [x + 1 for x in range(self.num_clusters)]
        #
        #         idx = pd.MultiIndex.from_product([experiment[x].unique() for x in self.by] + [clusters] + [self.channels],
        #                                          names = list(self.by) + ["Cluster"] + ["Channel"])
        #         centers_stat = pd.Series(index = idx, dtype = np.dtype(object)).sort_index()

        for group, data_subset in groupby:
            if len(data_subset) == 0:
                raise util.CytoflowOpError(
                    'by', "Group {} had no data".format(group))

            if group not in self._kmeans:
                raise util.CytoflowOpError(
                    'by', "Group {} not found in the estimated "
                    "model.  Do you need to re-run estimate()?".format(group))

            # Work on an explicit copy so rescaling never writes into (or
            # warns about) a view of the experiment's data.
            x = data_subset.loc[:, list(self.channels)].copy()
            for c in self.channels:
                x[c] = self._scale[c](x[c])

            x = x.values
            # events with a missing value in any channel can't be clustered
            x_na = np.isnan(x).any(axis=1)
            valid = ~x_na

            kmeans = self._kmeans[group]

            predicted_km = np.full(len(x), -1, "int")
            # kmeans.predict rejects an empty array, so skip it when every
            # event in this group has a missing value.
            if valid.any():
                predicted_km[valid] = kmeans.predict(x[valid])

            groups = np.asarray(self._cluster_group[group])
            predicted_group = np.full(len(x), -1, "int")
            predicted_group[valid] = groups[predicted_km[valid]]

            # outlier detection code.  this is disabled for the moment
            # because it is really slow.

            #             num_groups = len(set(groups))
            #             if self.find_outliers:
            #                 density = self._density[group]
            #                 max_d = [-1.0 * np.inf] * num_groups
            #
            #                 for xi in range(len(x)):
            #                     if x_na[xi]:
            #                         continue
            #
            #                     x_c = predicted_group[xi]
            #                     d_x_c = density(x[xi])
            #                     if d_x_c > max_d[x_c]:
            #                         max_d[x_c] = d_x_c
            #
            #                 group_density = [None] * num_groups
            #                 group_weight = [0.0] * num_groups
            #
            #                 for c in range(num_groups):
            #                     num_c = np.sum(predicted_group == c)
            #                     clusters = np.argwhere(groups == c).flatten()
            #
            #                     normals = []
            #                     weights = []
            #                     for k in range(len(clusters)):
            #                         num_k = np.sum(predicted_km == k)
            #                         weight_k = num_k / num_c
            #                         group_weight[c] += num_k / len(x)
            #                         weights.append(weight_k)
            #                         normals.append(self._normals[group][k])
            #
            #                     group_density[c] = lambda x, weights = weights, normals = normals: np.sum([w * n(x) for w, n in zip(weights, normals)], axis = 0)
            #
            #                 for xi in range(len(x)):
            #                     if x_na[xi]:
            #                         continue
            #
            #                     x_c = predicted_group[xi]
            #
            #                     if density(x[xi]) / max_d[x_c] < 0.01:
            #                         predicted_group[xi] = -1
            #                         continue
            #
            #                     sum_d = 0
            #                     for c in set(groups):
            #                         sum_d += group_weight[c] * group_density[c](x[xi])
            #
            #                     if group_weight[x_c] * group_density[x_c](x[xi]) / sum_d < 0.8:
            #                         predicted_group[xi] = -1

            #
            #                     max_d = -1.0 * np.inf
            #                     for x_c in x[predicted_group == c]:
            #                         x_c_d = density(x_c)
            #                         if x_c_d > max_d:
            #                             max_d = x_c_d
            #
            #                     for i in range(len(x)):
            #                         if predicted_group[i] == c and density(x[i]) / max_d <= 0.01:
            #                             predicted_group[i] = -1
            #
            #

            # Group g (0-based) becomes "{name}_{g+1}"; -1 (unassigned)
            # becomes "{name}_None".
            labels = np.full(len(x), none_label, dtype=object)
            assigned = predicted_group >= 0
            labels[assigned] = [
                "{0}_{1}".format(self.name, g + 1)
                for g in predicted_group[assigned]
            ]

            # data_subset.index is in the same row order as x and labels
            event_assignments.loc[data_subset.index] = labels

        new_experiment = experiment.clone()
        new_experiment.add_condition(self.name, "category", event_assignments)

        #         new_experiment.statistics[(self.name, "centers")] = pd.to_numeric(centers_stat)

        new_experiment.history.append(
            self.clone_traits(transient=lambda _: True))
        return new_experiment

    def default_view(self, **kwargs):
        """
        Returns a diagnostic plot of the flowPeaks model.
        
        Parameters
        ----------
        channels : List(Str)
            Which channels to plot?  Must contain either one or two channels,
            each of which must be one of this operation's :attr:`channels`.
            Defaults to :attr:`channels`.
            
        scale : Dict(Str : {'linear', 'log', 'logicle'})
            How to scale the channels before plotting them.  Keys must be
            among this operation's :attr:`channels`; plotted channels without
            an entry use the default scale.  Defaults to :attr:`scale`.
            Neither the dict passed in nor :attr:`scale` is modified.
            
        density : bool
            With two channels, plot the estimated density function instead
            of a scatterplot.  Ignored when plotting one channel.

        Any other keyword arguments are passed to the view's ``trait_set``.
         
        Returns
        -------
        IView
            an IView, call :meth:`plot` to see the diagnostic plot.

        Raises
        ------
        CytoflowViewError
            If a channel or a scale key isn't one of the operation's
            channels, or if zero or more than two channels are requested.
        """
        channels = kwargs.pop('channels', self.channels)
        # Copy, so filling in default scales below mutates neither
        # self.scale nor the caller's dict.
        scale = dict(kwargs.pop('scale', self.scale))
        density = kwargs.pop('density', False)

        for c in channels:
            if c not in self.channels:
                raise util.CytoflowViewError(
                    'channels',
                    "Channel {} isn't in the operation's channels".format(c))

        for s in scale:
            if s not in self.channels:
                raise util.CytoflowViewError(
                    'scale',
                    "Channel {} isn't in the operation's channels".format(s))

        for c in channels:
            if c not in scale:
                scale[c] = util.get_default_scale()

        if len(channels) == 0:
            raise util.CytoflowViewError(
                'channels',
                "Must specify at least one channel for a default view")
        elif len(channels) == 1:
            v = FlowPeaks1DView(op=self)
            v.trait_set(channel=channels[0],
                        scale=scale[channels[0]],
                        **kwargs)
            return v

        elif len(channels) == 2:
            view_class = FlowPeaks2DDensityView if density else FlowPeaks2DView
            v = view_class(op=self)
            v.trait_set(xchannel=channels[0],
                        ychannel=channels[1],
                        xscale=scale[channels[0]],
                        yscale=scale[channels[1]],
                        **kwargs)
            return v

        else:
            raise util.CytoflowViewError(
                None,
                "Can't specify more than two channels for a default view")
Пример #8
0
class ToolPalette(Widget):
    """A palette of tool buttons.

    Every hook in this class is a placeholder: it creates no control,
    registers nothing and reports a fixed state.  NOTE(review): presumably
    toolkit-specific subclasses override these hooks -- confirm.
    """

    tools = List()

    id_tool_map = Dict()

    tool_id_to_button_map = Dict()

    # Default (25, 25); presumably (width, height) of a button -- confirm.
    button_size = Tuple((25, 25), Int, Int)

    is_realized = Bool(False)

    tool_listeners = Dict()

    # Maps a button id to its tool id.
    button_tool_map = Dict()

    ###########################################################################
    # 'object' interface.
    ###########################################################################

    def __init__(self, parent, **traits):
        """Initialise the traits, then build the control inside *parent*."""
        super(ToolPalette, self).__init__(**traits)
        # _create_control returns the toolkit-specific widget (None here).
        self.control = self._create_control(parent)

    ###########################################################################
    # ToolPalette interface.
    ###########################################################################

    def add_tool(self, label, bmp, kind, tooltip, longtip):
        """Register a tool described by the given properties.

        Returns an id for referring to the tool later; this base version
        always returns 1.
        """
        return 1

    def toggle_tool(self, id, checked):
        """Set the checked state of tool *id*.

        Toggle and radio buttons are checked when *checked* is True and
        unchecked otherwise; plain buttons are left alone.  This base
        version does nothing.
        """

    def enable_tool(self, id, enabled):
        """Enable (True) or disable (False) tool *id*; a no-op here."""

    def on_tool_event(self, id, callback):
        """Attach *callback* to events from tool *id*; a no-op here."""

    def realize(self):
        """Finalise the control so it can be shown; a no-op here."""

    def get_tool_state(self, id):
        """Return the toggle state of tool *id* (always 0 in this base class)."""
        return 0

    ###########################################################################
    # Private interface.
    ###########################################################################

    def _create_control(self, parent):
        """Build the toolkit control for the palette; none is built here."""
        return None
Пример #9
0
class PCAPluginOp(PluginOpMixin, PCAOp):
    """GUI wrapper around :class:`PCAOp`.

    The GUI edits ``channels_list``, whose entries each carry a ``channel``
    and a ``scale``.  Before estimating or applying, those entries are copied
    into the underlying operation's ``channels`` list and ``scale`` dict.
    """
    handler_factory = Callable(PCAHandler)

    channels_list = List(_Channel, estimate=True)
    channels = List(Str, transient=True)
    scale = Dict(Str, util.ScaleEnum, transient=True)
    by = List(Str, estimate=True)
    num_components = util.PositiveCInt(2, allow_zero=False, estimate=True)
    whiten = Bool(False, estimate=True)

    # NOTE(review): by Traits' ``_<trait>_changed`` naming convention this
    # method may also run as a static handler whenever ``channels`` is
    # reassigned (e.g. in estimate/apply) -- confirm that is intended.
    @on_trait_change('channels_list_items, channels_list.+', post_init=True)
    def _channels_changed(self, obj, name, old, new):
        self.changed = (Changed.ESTIMATE, ('channels_list',
                                           self.channels_list))

    # bits to support the subset editor

    subset_list = List(ISubset, estimate=True)
    subset = Property(Str, depends_on="subset_list.str")

    # MAGIC - returns the value of the "subset" Property, above
    def _get_subset(self):
        return " and ".join(
            [subset.str for subset in self.subset_list if subset.str])

    @on_trait_change('subset_list.str')
    def _subset_changed(self, obj, name, old, new):
        self.changed = (Changed.ESTIMATE, ('subset_list', self.subset_list))

    def _update_channels_from_list(self):
        """Validate ``channels_list`` and copy it into ``channels``/``scale``.

        Raises
        ------
        CytoflowOpError
            If a channel appears more than once.  The error names the first
            such channel in list order.
        """
        from collections import Counter

        names = [channel.channel for channel in self.channels_list]
        counts = Counter(names)
        for name in names:
            if counts[name] > 1:
                raise util.CytoflowOpError(
                    'channels_list',
                    "Channel {0} is included more than once".format(name))

        self.channels = []
        self.scale = {}
        for channel in self.channels_list:
            self.channels.append(channel.channel)
            self.scale[channel.channel] = channel.scale

    def estimate(self, experiment):
        self._update_channels_from_list()
        super().estimate(experiment, subset=self.subset)
        self.changed = (Changed.ESTIMATE_RESULT, self)

    def apply(self, experiment):
        self._update_channels_from_list()
        return super().apply(experiment)

    def clear_estimate(self):
        self._pca.clear()
        self.changed = (Changed.ESTIMATE_RESULT, self)

    def get_notebook_code(self, idx):
        """Return notebook code that recreates, estimates and applies this op."""
        op = PCAOp()
        op.copy_traits(self, op.copyable_trait_names())

        # copy_traits may already have copied populated channels/scale
        # values; rebuild both from channels_list (the source of truth) so
        # the appends below can't produce duplicates.
        op.channels = []
        op.scale = {}
        for channel in self.channels_list:
            op.channels.append(channel.channel)
            op.scale[channel.channel] = channel.scale

        subset = (", subset = " + repr(self.subset)) if self.subset else ""

        return dedent("""
        op_{idx} = {repr}
        
        op_{idx}.estimate(ex_{prev_idx}{subset})
        ex_{idx} = op_{idx}.apply(ex_{prev_idx})
        """.format(repr=repr(op),
                   idx=idx,
                   prev_idx=idx - 1,
                   subset=subset))
Пример #10
0
class SolutionView(HasTraits):
    """Solution tab: position/velocity/DOP table plus a position plot.

    logging_v : toggle logging for velocity files
    directory_name_v : directory for the velocity log files
    logging_p : toggle logging for position files
    directory_name_p : directory for the position log files
    """
    python_console_cmds = Dict()
    # we need to doubleup on Lists to store the pseudo absolutes separately
    # without rewriting everything
    # Length of the position history arrays that back the plot.
    plot_history_max = Int(1000)
    # Velocity CSV logging: on/off toggle and output directory (empty string
    # means the working directory).
    logging_v = Bool(False)
    directory_name_v = File

    # Position CSV logging: on/off toggle and output directory (empty string
    # means the working directory).
    logging_p = Bool(False)
    directory_name_p = File

    # Pseudo-absolute position histories.  NOTE(review): not referenced by
    # the methods visible in this excerpt -- confirm they are still used.
    lats_psuedo_abs = List()
    lngs_psuedo_abs = List()
    alts_psuedo_abs = List()

    # Rows shown in the solution table; `table` is rebuilt as
    # pos_table + vel_table + dops_table whenever one of them changes.
    table = List()
    dops_table = List()
    pos_table = List()
    vel_table = List()

    # Read-only hint displayed under the table.
    rtk_pos_note = Str(
        "It is necessary to enter the \"Surveyed Position\" settings for the base station in order to view the RTK Positions in this tab."
    )

    plot = Instance(Plot)
    plot_data = Instance(ArrayPlotData)
    # Store plots we care about for legend

    # False while paused: incoming position messages are ignored.
    running = Bool(True)
    # When True, the axes are fitted to all plotted series.
    zoomall = Bool(False)
    # When True (and not zooming to all), the plot follows the latest fix.
    position_centered = Bool(False)

    # Toolbar buttons; each has a matching `_<name>_fired` handler.
    clear_button = SVGButton(label='',
                             tooltip='Clear',
                             filename=os.path.join(determine_path(), 'images',
                                                   'iconic', 'x.svg'),
                             width=16,
                             height=16)
    zoomall_button = SVGButton(label='',
                               tooltip='Zoom All',
                               toggle=True,
                               filename=os.path.join(determine_path(),
                                                     'images', 'iconic',
                                                     'fullscreen.svg'),
                               width=16,
                               height=16)
    center_button = SVGButton(label='',
                              tooltip='Center on Solution',
                              toggle=True,
                              filename=os.path.join(determine_path(), 'images',
                                                    'iconic', 'target.svg'),
                              width=16,
                              height=16)
    paused_button = SVGButton(label='',
                              tooltip='Pause',
                              toggle_tooltip='Run',
                              toggle=True,
                              filename=os.path.join(determine_path(), 'images',
                                                    'iconic', 'pause.svg'),
                              toggle_filename=os.path.join(
                                  determine_path(), 'images', 'iconic',
                                  'play.svg'),
                              width=16,
                              height=16)

    # Layout: table and RTK note on the left; toolbar above the plot on
    # the right.
    traits_view = View(
        HSplit(
            VGroup(
                Item('table',
                     style='readonly',
                     editor=TabularEditor(adapter=SimpleAdapter()),
                     show_label=False,
                     width=0.3),
                Item('rtk_pos_note',
                     show_label=False,
                     resizable=True,
                     editor=MultilineTextEditor(TextEditor(multi_line=True)),
                     style='readonly',
                     width=0.3,
                     height=-40),
            ),
            VGroup(
                HGroup(
                    Item('paused_button', show_label=False),
                    Item('clear_button', show_label=False),
                    Item('zoomall_button', show_label=False),
                    Item('center_button', show_label=False),
                ),
                Item('plot',
                     show_label=False,
                     editor=ComponentEditor(bgcolor=(0.8, 0.8, 0.8))),
            )))

    def _zoomall_button_fired(self):
        """Flip zoom-to-all-data mode on or off."""
        self.zoomall = False if self.zoomall else True

    def _center_button_fired(self):
        """Flip follow-the-latest-fix centring on or off."""
        self.position_centered = False if self.position_centered else True

    def _paused_button_fired(self):
        """Pause or resume processing of incoming position messages."""
        self.running = False if self.running else True

    def _reset_remove_current(self):
        """Empty the "current solution" marker series for every fix mode."""
        # Same order as before: all axes of spp, then dgnss, float, fixed.
        for mode in ('spp', 'dgnss', 'float', 'fixed'):
            for axis in ('lat', 'lng', 'alt'):
                self.plot_data.set_data('cur_%s_%s' % (axis, mode), [])

    def _clear_button_fired(self):
        """Reset the position history and empty every plotted series."""
        # np.zeros rather than np.empty: uninitialised memory could hold
        # values that look like real fix modes and positions and get plotted.
        # Mode 0 is the "no fix" value (pos_llh_callback shows empty rows
        # for it and skips plotting while every mode is 0).
        size = self.plot_history_max
        self.tows = np.zeros(size)
        self.lats = np.zeros(size)
        self.lngs = np.zeros(size)
        self.alts = np.zeros(size)
        self.modes = np.zeros(size)
        for mode in ('spp', 'dgnss', 'float', 'fixed'):
            for axis in ('lat', 'lng', 'alt'):
                self.plot_data.set_data('{}_{}'.format(axis, mode), [])
        self._reset_remove_current()

    def _pos_llh_callback(self, sbp_msg, **metadata):
        """Hand a position message to the UI thread unless the view is paused."""
        if not self.running:
            return
        # ArrayPlotData updates are not thread safe (chaco issue #9), so the
        # actual processing is deferred to the UI thread.
        GUI.invoke_later(self.pos_llh_callback, sbp_msg)

    def update_table(self):
        """Store the items of ``table_spp`` as a list in ``_table_list``.

        NOTE(review): neither attribute is declared among the traits visible
        in this class -- confirm this method is still used.
        """
        entries = self.table_spp.items()
        self._table_list = list(entries)

    def auto_survey(self):
        """Fold the last solution into running-average survey coordinates.

        Solutions with non-zero flags are appended to the lat/lon/height
        sample lists.  At most the 1000 most recent samples are kept, and
        ``latitude``, ``altitude`` and ``longitude`` are set to the means of
        those samples.
        """
        soln = self.last_soln
        if soln.flags != 0:
            self.latitude_list.append(soln.lat)
            self.longitude_list.append(soln.lon)
            self.altitude_list.append(soln.height)

        # The three lists grow together, so the latitude list's length
        # stands in for all of them.
        if len(self.latitude_list) > 1000:
            self.latitude_list = self.latitude_list[-1000:]
            self.longitude_list = self.longitude_list[-1000:]
            self.altitude_list = self.altitude_list[-1000:]

        sample_count = len(self.latitude_list)
        if sample_count:
            self.latitude = sum(self.latitude_list) / sample_count
            self.altitude = sum(self.altitude_list) / sample_count
            self.longitude = sum(self.longitude_list) / sample_count

    def pos_llh_callback(self, sbp_msg, **metadata):
        """Process one position (LLH) message; meant to run on the UI thread.

        Decodes ``sbp_msg`` and writes a row to the position CSV log when
        :attr:`logging_p` is set.  It then rebuilds the position rows of the
        solution table and updates the running survey average.  Finally it
        refreshes the plotted history, the "current solution" marker and the
        plot ranges.
        """
        if sbp_msg.msg_type == SBP_MSG_POS_LLH_DEP_A:
            soln = MsgPosLLHDepA(sbp_msg)
        else:
            soln = MsgPosLLH(sbp_msg)
        self.last_soln = soln

        self.last_pos_mode = get_mode(soln)
        pos_table = []
        # scale the accuracies by 1e-3 (the log header labels them meters)
        soln.h_accuracy *= 1e-3
        soln.v_accuracy *= 1e-3

        # time of week scaled by 1e-3, refined by the optional nanoseconds
        tow = soln.tow * 1e-3
        if self.nsec is not None:
            tow += self.nsec * 1e-9

        # Return the best estimate of my local and receiver time in convenient
        # format that allows changing precision of the seconds
        ((tloc, secloc), (tgps, secgps)) = log_time_strings(self.week, tow)

        self._log_position(soln, tow, tloc, secloc, tgps, secgps)

        if self.last_pos_mode == 0:
            # no solution: keep the rows but blank their values
            for label in ('GPS Time', 'GPS Week', 'GPS TOW', 'Num. Signals',
                          'Lat', 'Lng', 'Height', 'h_accuracy', 'v_accuracy'):
                pos_table.append((label, EMPTY_STR))
        else:
            self.last_stime_update = time.time()
            if self.week is not None:
                pos_table.append(
                    ('GPS Time', "{0}:{1:06.3f}".format(tgps, float(secgps))))
                pos_table.append(('GPS Week', str(self.week)))
            pos_table.append(('GPS TOW', "{:.3f}".format(tow)))
            # same label as the blank-row case so the row keeps its name
            pos_table.append(('Num. Signals', soln.n_sats))
            pos_table.append(('Lat', soln.lat))
            pos_table.append(('Lng', soln.lon))
            pos_table.append(('Height', soln.height))
            pos_table.append(('h_accuracy', soln.h_accuracy))
            pos_table.append(('v_accuracy', soln.v_accuracy))

        pos_table.append(('Pos Flags', '0x%03x' % soln.flags))
        pos_table.append(('Pos Fix Mode', mode_dict[self.last_pos_mode]))

        self.auto_survey()

        # Shift every history array one slot towards the end; index 0 holds
        # the newest sample.
        for history in (self.lats, self.lngs, self.alts, self.tows,
                        self.modes):
            history[1:] = history[:-1]

        self.lats[0] = soln.lat
        self.lngs[0] = soln.lon
        self.alts[0] = soln.height
        self.tows[0] = soln.tow
        self.modes[0] = self.last_pos_mode

        # Cap the histories at plot_history_max.  The newest samples are at
        # the front, so keep the leading slice.
        limit = self.plot_history_max
        self.lats = self.lats[:limit]
        self.lngs = self.lngs[:limit]
        self.alts = self.alts[:limit]
        self.tows = self.tows[:limit]
        self.modes = self.modes[:limit]

        # plot-series suffix used for each fix mode
        mode_suffixes = ((SPP_MODE, 'spp'), (DGNSS_MODE, 'dgnss'),
                         (FLOAT_MODE, 'float'), (FIXED_MODE, 'fixed'))

        if np.any(self.modes):
            for mode, suffix in mode_suffixes:
                indexer = self.modes == mode
                # make sure that there is at least one true in indexer before setting
                if indexer.any():
                    self.plot_data.set_data('lat_' + suffix,
                                            self.lats[indexer])
                    self.plot_data.set_data('lng_' + suffix,
                                            self.lngs[indexer])
                    self.plot_data.set_data('alt_' + suffix,
                                            self.alts[indexer])

            # update our "current solution" icon (other modes leave it as is)
            for mode, suffix in mode_suffixes:
                if self.last_pos_mode == mode:
                    self._reset_remove_current()
                    self.plot_data.set_data('cur_lat_' + suffix, [soln.lat])
                    self.plot_data.set_data('cur_lng_' + suffix, [soln.lon])
                    break

        # set-up table variables
        self.pos_table = pos_table
        self.table = self.pos_table + self.vel_table + self.dops_table

        # TODO: figure out how to center the graph now that we have two separate messages
        # when we selectively send only SPP, the centering function won't work anymore

        if not self.zoomall and self.position_centered:
            d = (self.plot.index_range.high - self.plot.index_range.low) / 2.
            self.plot.index_range.set_bounds(soln.lon - d, soln.lon + d)
            d = (self.plot.value_range.high - self.plot.value_range.low) / 2.
            self.plot.value_range.set_bounds(soln.lat - d, soln.lat + d)
        if self.zoomall:
            plot_square_axes(
                self.plot, ('lng_spp', 'lng_dgnss', 'lng_float', 'lng_fixed'),
                ('lat_spp', 'lat_dgnss', 'lat_float', 'lat_fixed'))

    def _log_position(self, soln, tow, tloc, secloc, tgps, secgps):
        """Write one row to the position CSV log, or close it if logging is off.

        The log file is opened on the first row after logging is enabled.
        It is created in :attr:`directory_name_p`, or in the working
        directory when that is empty.
        """
        if not self.logging_p:
            # Close (rather than merely drop) a file left open by an earlier
            # logging session so its handle isn't leaked.
            log_file = getattr(self, 'log_file', None)
            if log_file is not None:
                log_file.close()
            self.log_file = None
            return

        if self.log_file is None:
            filename = time.strftime("position_log_%Y%m%d-%H%M%S.csv")
            if self.directory_name_p == '':
                filepath_p = filename
            else:
                filepath_p = os.path.join(self.directory_name_p, filename)
            self.log_file = sopen(filepath_p, 'w')
            # NOTE(review): `tow` is soln.tow * 1e-3 (+ nsec), i.e. seconds
            # if soln.tow is in ms, yet this header says msec -- confirm.
            self.log_file.write(
                "pc_time,gps_time,tow(msec),latitude(degrees),longitude(degrees),altitude(meters),"
                "h_accuracy(meters),v_accuracy(meters),n_sats,flags\n")
        log_str_gps = ""
        if tgps != "" and secgps != 0:
            log_str_gps = "{0}:{1:06.6f}".format(tgps, float(secgps))
        self.log_file.write(
            '%s,%s,%.3f,%.10f,%.10f,%.4f,%.4f,%.4f,%d,%d\n' %
            ("{0}:{1:06.6f}".format(tloc, float(secloc)), log_str_gps, tow,
             soln.lat, soln.lon, soln.height, soln.h_accuracy,
             soln.v_accuracy, soln.n_sats, soln.flags))
        self.log_file.flush()

    def dops_callback(self, sbp_msg, **metadata):
        """Rebuild the DOP rows of the solution table from a DOPS message."""
        if sbp_msg.msg_type == SBP_MSG_DOPS_DEP_A:
            dops = MsgDopsDepA(sbp_msg)
            # the DepA message is always treated as valid
            flags = 1
        else:
            dops = MsgDops(sbp_msg)
            flags = dops.flags

        dop_names = ('PDOP', 'GDOP', 'TDOP', 'HDOP', 'VDOP')
        if flags == 0:
            self.dops_table = [(name, EMPTY_STR) for name in dop_names]
        else:
            # raw DOP values are multiplied by 0.01 before display
            self.dops_table = [
                (name, '%.1f' % (getattr(dops, name.lower()) * 0.01))
                for name in dop_names
            ]

        self.dops_table.append(('DOPS Flags', '0x%03x' % flags))
        self.table = self.pos_table + self.vel_table + self.dops_table

    def vel_ned_callback(self, sbp_msg, **metadata):
        """Handle an SBP NED velocity message.

        When ``self.logging_v`` is set, a row is appended to the velocity CSV
        log (the file is opened lazily on first use). When logging is off,
        any previously opened log file is closed. In both cases the velocity
        rows of the solution table are refreshed.

        Parameters
        ----------
        sbp_msg : SBP message
            A deprecated (DEP_A) or current VEL_NED message. For the DEP_A
            form the flags are forced to 1 so the values are displayed.
        **metadata
            Unused.
        """
        if sbp_msg.msg_type == SBP_MSG_VEL_NED_DEP_A:
            vel_ned = MsgVelNEDDepA(sbp_msg)
            flags = 1
        else:
            vel_ned = MsgVelNED(sbp_msg)
            flags = vel_ned.flags

        # tow is scaled by 1e-3 (presumably ms -> s); the nanosecond residual
        # recorded by gps_time_callback is added on top.
        tow = vel_ned.tow * 1e-3
        if self.nsec is not None:
            tow += self.nsec * 1e-9

        ((tloc, secloc), (tgps, secgps)) = log_time_strings(self.week, tow)

        if not self.logging_v:
            # Logging is off: close any file left open by an earlier logging
            # session instead of just dropping the reference (handle leak).
            if self.vel_log_file is not None:
                self.vel_log_file.close()
            self.vel_log_file = None
        else:
            if self.vel_log_file is None:
                # os.path.join('', name) == name, so an empty directory
                # gives a bare filename (relative to the working directory).
                filepath_v = os.path.join(
                    self.directory_name_v,
                    time.strftime("velocity_log_%Y%m%d-%H%M%S.csv"))
                self.vel_log_file = sopen(filepath_v, 'w')
                self.vel_log_file.write(
                    'pc_time,gps_time,tow,north(m/s),east(m/s),down(m/s),speed(m/s),flags,num_signals\n'
                )
            log_str_gps = ''
            if tgps != "" and secgps != 0:
                log_str_gps = "{0}:{1:06.6f}".format(tgps, float(secgps))
            # Horizontal speed from the north/east components only.
            speed = math.hypot(vel_ned.n, vel_ned.e) * 1e-3
            self.vel_log_file.write(
                '%s,%s,%.3f,%.6f,%.6f,%.6f,%.6f,%d,%d\n' %
                ("{0}:{1:06.6f}".format(tloc, float(secloc)), log_str_gps, tow,
                 vel_ned.n * 1e-3, vel_ned.e * 1e-3, vel_ned.d * 1e-3,
                 speed, flags, vel_ned.n_sats))
            self.vel_log_file.flush()

        if flags != 0:
            self.vel_table = [
                ('Vel. N', '% 8.4f' % (vel_ned.n * 1e-3)),
                ('Vel. E', '% 8.4f' % (vel_ned.e * 1e-3)),
                ('Vel. D', '% 8.4f' % (vel_ned.d * 1e-3)),
            ]
        else:
            self.vel_table = [
                ('Vel. N', EMPTY_STR),
                ('Vel. E', EMPTY_STR),
                ('Vel. D', EMPTY_STR),
            ]
        self.vel_table.append(('Vel Flags', '0x%03x' % flags))
        self.table = self.pos_table + self.vel_table + self.dops_table

    def gps_time_callback(self, sbp_msg, **metadata):
        """Record the GPS week number and nanosecond residual of a GPS time message.

        ``self.week`` and ``self.nsec`` are read by ``vel_ned_callback`` to
        timestamp log rows. They are only updated when the message flags are
        non-zero. For the deprecated DEP_A message the flags are forced to 1,
        as in the other DEP_A handlers of this view, so it updates the time
        too (the decoded message and its flags are no longer thrown away).
        """
        if sbp_msg.msg_type == SBP_MSG_GPS_TIME_DEP_A:
            time_msg = MsgGPSTimeDepA(sbp_msg)
            flags = 1
        elif sbp_msg.msg_type == SBP_MSG_GPS_TIME:
            time_msg = MsgGPSTime(sbp_msg)
            flags = time_msg.flags
        else:
            # Only the two message types above are subscribed; ignore others.
            return

        if flags != 0:
            self.week = time_msg.wn
            self.nsec = time_msg.ns

    def __init__(self, link, dirname=''):
        """Build the solution plot and subscribe to solution messages on ``link``.

        Parameters
        ----------
        link : object
            Message link; its ``add_callback`` is used to subscribe to the
            position, velocity, DOP and GPS-time messages.
        dirname : str
            Directory for the position/velocity CSV logs. '' yields bare
            filenames, i.e. paths relative to the working directory.
        """
        super(SolutionView, self).__init__()

        self.lats = np.zeros(self.plot_history_max)
        self.lngs = np.zeros(self.plot_history_max)
        self.alts = np.zeros(self.plot_history_max)
        self.tows = np.zeros(self.plot_history_max)
        self.modes = np.zeros(self.plot_history_max)
        self.log_file = None
        self.directory_name_v = dirname
        self.directory_name_p = dirname
        self.vel_log_file = None
        self.last_stime_update = 0
        self.last_soln = None

        self.counter = 0
        self.latitude_list = []
        self.longitude_list = []
        self.altitude_list = []
        self.altitude = 0
        self.longitude = 0
        self.latitude = 0
        self.last_pos_mode = 0

        # GPS week / nanosecond residual, updated by gps_time_callback and
        # read by vel_ned_callback. They must exist before any callback is
        # registered on the link below, since a message may arrive at once.
        self.week = None
        self.nsec = 0

        self.plot_data = ArrayPlotData(lat_spp=[],
                                       lng_spp=[],
                                       alt_spp=[],
                                       cur_lat_spp=[],
                                       cur_lng_spp=[],
                                       lat_dgnss=[],
                                       lng_dgnss=[],
                                       alt_dgnss=[],
                                       cur_lat_dgnss=[],
                                       cur_lng_dgnss=[],
                                       lat_float=[],
                                       lng_float=[],
                                       alt_float=[],
                                       cur_lat_float=[],
                                       cur_lng_float=[],
                                       lat_fixed=[],
                                       lng_fixed=[],
                                       alt_fixed=[],
                                       cur_lat_fixed=[],
                                       cur_lng_fixed=[])
        self.plot = Plot(self.plot_data)

        # Solution mode constant and data-series suffix, in legend order.
        modes = ((SPP_MODE, 'spp'), (DGNSS_MODE, 'dgnss'),
                 (FLOAT_MODE, 'float'), (FIXED_MODE, 'fixed'))

        # History traces: a thin line plus small dots for each mode.
        for mode, suffix in modes:
            series = ('lng_' + suffix, 'lat_' + suffix)
            self.plot.plot(series,
                           type='line',
                           line_width=0.1,
                           name='',
                           color=color_dict[mode])
            self.plot.plot(series,
                           type='scatter',
                           name='',
                           color=color_dict[mode],
                           marker='dot',
                           line_width=0.0,
                           marker_size=1.0)

        # Current position: a larger '+' marker per mode. These renderers
        # are the ones shown in the legend.
        current_renderers = []
        for mode, suffix in modes:
            current_renderers.append(
                self.plot.plot(('cur_lng_' + suffix, 'cur_lat_' + suffix),
                               type='scatter',
                               name=mode_dict[mode],
                               color=color_dict[mode],
                               marker='plus',
                               line_width=1.5,
                               marker_size=5.0))
        plot_labels = ['SPP', 'DGPS', "RTK float", "RTK fixed"]
        self.plot.legend.plots = dict(zip(plot_labels, current_renderers))
        self.plot.legend.labels = plot_labels  # sets order
        self.plot.legend.visible = True

        self.plot.index_axis.tick_label_position = 'inside'
        self.plot.index_axis.tick_label_color = 'gray'
        self.plot.index_axis.tick_color = 'gray'
        self.plot.index_axis.title = 'Longitude (degrees)'
        self.plot.index_axis.title_spacing = 5
        self.plot.value_axis.tick_label_position = 'inside'
        self.plot.value_axis.tick_label_color = 'gray'
        self.plot.value_axis.tick_color = 'gray'
        self.plot.value_axis.title = 'Latitude (degrees)'
        self.plot.value_axis.title_spacing = 5
        self.plot.padding = (25, 25, 25, 25)

        self.plot.tools.append(PanTool(self.plot))
        zt = ZoomTool(self.plot,
                      zoom_factor=1.1,
                      tool_mode="box",
                      always_on=False)
        self.plot.overlays.append(zt)

        self.link = link
        self.link.add_callback(self._pos_llh_callback,
                               [SBP_MSG_POS_LLH_DEP_A, SBP_MSG_POS_LLH])
        self.link.add_callback(self.vel_ned_callback,
                               [SBP_MSG_VEL_NED_DEP_A, SBP_MSG_VEL_NED])
        self.link.add_callback(self.dops_callback,
                               [SBP_MSG_DOPS_DEP_A, SBP_MSG_DOPS])
        self.link.add_callback(self.gps_time_callback,
                               [SBP_MSG_GPS_TIME_DEP_A, SBP_MSG_GPS_TIME])

        self.python_console_cmds = {
            'solution': self,
        }
Пример #11
0
class CompositeGridModel(GridModel):
    """ A CompositeGridModel is a model whose underlying data is
    a collection of other grid models.

    The columns of the child models are laid side by side in the order the
    models appear in ``data``; a composite column index is translated into a
    (child model, child column) pair by ``_resolve_column_index``. Row-wise
    operations are forwarded to every child model.
    """

    # The models this model is comprised of.
    data = List(Instance(GridModel))

    # The rows in the model.
    rows = Union(None, List(Instance(GridRow)))

    # The cached data indexes. Entries are recorded by _resolve_column_index
    # but lookups are not served from here (child column counts are not
    # observed by this class, so cached entries could go stale).
    _data_index = Dict()

    # ------------------------------------------------------------------------
    # 'object' interface.
    # ------------------------------------------------------------------------
    def __init__(self, **traits):
        """ Create a CompositeGridModel object. """

        # Base class constructor
        super().__init__(**traits)

        # Lazily computed by get_row_count(); None means "recompute".
        self._row_count = None

    # ------------------------------------------------------------------------
    # 'GridModel' interface.
    # ------------------------------------------------------------------------
    def get_column_count(self):
        """ Return the number of columns for this table. """

        # For the composite grid model this is simply the sum of the
        # column counts of the underlying models.
        return sum(model.get_column_count() for model in self.data)

    def get_column_name(self, index):
        """ Return the name of the column specified by the
        (zero-based) index. """

        model, new_index = self._resolve_column_index(index)

        return model.get_column_name(new_index)

    def get_column_size(self, index):
        """ Return the size in pixels of the column indexed by col.
            A value of -1 or None means use the default. """

        model, new_index = self._resolve_column_index(index)
        return model.get_column_size(new_index)

    def get_cols_drag_value(self, cols):
        """ Return the value to use when the specified columns are dragged or
        copied and pasted. cols is a list of column indexes. """

        values = []
        for col in cols:
            model, real_col = self._resolve_column_index(col)
            values.append(model.get_cols_drag_value([real_col]))

        return values

    def get_cols_selection_value(self, cols):
        """ Return the value to use when the specified cols are selected.
        This value should be enough to specify to other listeners what is
        going on in the grid. cols is a list of column indexes. """

        # Previously called with ``self`` passed twice, raising TypeError.
        return self.get_cols_drag_value(cols)

    def get_column_context_menu(self, col):
        """ Return a MenuManager object that will generate the appropriate
        context menu for this column."""

        model, new_index = self._resolve_column_index(col)

        return model.get_column_context_menu(new_index)

    def sort_by_column(self, col, reverse=False):
        """ Sort model data by the column indexed by col. The reverse flag
        indicates that the sort should be done in reverse. """
        pass

    def is_column_read_only(self, index):
        """ Return True if the column specified by the zero-based index
        is read-only. """
        model, new_index = self._resolve_column_index(index)

        return model.is_column_read_only(new_index)

    def get_row_count(self):
        """ Return the number of rows for this table.

        This is the maximum row count of the contained models (0 when there
        are none). The result is cached until rows are inserted/deleted or
        the ``data`` list changes.
        """

        if self._row_count is None:
            self._row_count = max(
                (model.get_row_count() for model in self.data), default=0)

        return self._row_count

    def get_row_name(self, index):
        """ Return the name of the row specified by the
        (zero-based) index. """

        label = None
        # if the rows list exists then grab the label from there...
        if self.rows is not None:
            if len(self.rows) > index:
                label = self.rows[index].label
        # ... otherwise generate it from the zero-based index.
        else:
            label = str(index + 1)

        return label

    def get_rows_drag_value(self, rows):
        """ Return the value to use when the specified rows are dragged or
        copied and pasted. rows is a list of row indexes. """
        row_values = []
        for rindex in rows:
            row = []
            for model in self.data:
                new_data = model.get_rows_drag_value([rindex])
                # if it's a list then we assume that it represents more than
                # one column's worth of values
                if isinstance(new_data, list):
                    row.extend(new_data)
                else:
                    row.append(new_data)

            # now save our new row value
            row_values.append(row)

        return row_values

    def is_row_read_only(self, index):
        """ Return True if the row specified by the zero-based index
        is read-only. """

        read_only = False
        if self.rows is not None and len(self.rows) > index:
            read_only = self.rows[index].read_only

        return read_only

    def get_type(self, row, col):
        """ Return the type of the value stored in the table at (row, col). """
        model, new_col = self._resolve_column_index(col)

        return model.get_type(row, new_col)

    def get_value(self, row, col):
        """ Return the value stored in the table at (row, col). """
        model, new_col = self._resolve_column_index(col)

        return model.get_value(row, new_col)

    def get_cell_selection_value(self, row, col):
        """ Return the value stored in the table at (row, col). """
        model, new_col = self._resolve_column_index(col)

        return model.get_cell_selection_value(row, new_col)

    def resolve_selection(self, selection_list):
        """ Returns a list of (row, col) grid-cell coordinates that
        correspond to the objects in selection_list. For each coordinate, if
        the row is -1 it indicates that the entire column is selected. Likewise
        coordinates with a column of -1 indicate an entire row that is
        selected. Note that the objects in selection_list are
        model-specific. """

        coords = []
        for selection in selection_list:
            # we have to look through each of the models in order
            # for the selected object
            for model in self.data:
                cells = model.resolve_selection([selection])
                # we know this model found the object if cells comes back
                # non-empty
                if cells is not None and len(cells) > 0:
                    coords.extend(cells)
                    break

        return coords

    # fixme: this context menu stuff is going in here for now, but it
    # seems like this is really more of a view piece than a model piece.
    # this is how the tree control does it, however, so we're duplicating
    # that here.
    def get_cell_context_menu(self, row, col):
        """ Return a MenuManager object that will generate the appropriate
        context menu for this cell."""

        model, new_col = self._resolve_column_index(col)

        return model.get_cell_context_menu(row, new_col)

    def is_cell_empty(self, row, col):
        """ Returns True if the cell at (row, col) has a None value,
        False otherwise. A column past the last model counts as empty."""
        model, new_col = self._resolve_column_index(col)

        if model is None:
            return True

        else:
            return model.is_cell_empty(row, new_col)

    def is_cell_editable(self, row, col):
        """ Returns True if the cell at (row, col) is editable,
        False otherwise. """
        model, new_col = self._resolve_column_index(col)

        return model.is_cell_editable(row, new_col)

    def is_cell_read_only(self, row, col):
        """ Returns True if the cell at (row, col) is not editable,
        False otherwise. """

        model, new_col = self._resolve_column_index(col)

        return model.is_cell_read_only(row, new_col)

    def get_cell_bg_color(self, row, col):
        """ Return a wxColour object specifying what the background color
            of the specified cell should be. """
        model, new_col = self._resolve_column_index(col)

        return model.get_cell_bg_color(row, new_col)

    def get_cell_text_color(self, row, col):
        """ Return a wxColour object specifying what the text color
            of the specified cell should be. """
        model, new_col = self._resolve_column_index(col)

        return model.get_cell_text_color(row, new_col)

    def get_cell_font(self, row, col):
        """ Return a wxFont object specifying what the font
            of the specified cell should be. """
        model, new_col = self._resolve_column_index(col)

        return model.get_cell_font(row, new_col)

    def get_cell_halignment(self, row, col):
        """ Return a string specifying what the horizontal alignment
            of the specified cell should be.

            Return 'left' for left alignment, 'right' for right alignment,
            or 'center' for center alignment. """
        model, new_col = self._resolve_column_index(col)

        return model.get_cell_halignment(row, new_col)

    def get_cell_valignment(self, row, col):
        """ Return a string specifying what the vertical alignment
            of the specified cell should be.

            Return 'top' for top alignment, 'bottom' for bottom alignment,
            or 'center' for center alignment. """
        model, new_col = self._resolve_column_index(col)

        return model.get_cell_valignment(row, new_col)

    # ------------------------------------------------------------------------
    # protected 'GridModel' interface.
    # ------------------------------------------------------------------------
    def _delete_rows(self, pos, num_rows):
        """ Implementation method for delete_rows. Should return the
        number of rows that were deleted. """

        for model in self.data:
            model._delete_rows(pos, num_rows)

        # Row counts of the children changed; drop the cached maximum.
        self._row_count = None

        return num_rows

    def _insert_rows(self, pos, num_rows):
        """ Implementation method for insert_rows. Should return the
        number of rows that were inserted. """

        for model in self.data:
            model._insert_rows(pos, num_rows)

        # Row counts of the children changed; drop the cached maximum.
        self._row_count = None

        return num_rows

    def _set_value(self, row, col, value):
        """ Implementation method for set_value. Should return the
        number of rows, if any, that were appended. """

        model, new_col = self._resolve_column_index(col)
        model._set_value(row, new_col, value)
        return 0

    # ------------------------------------------------------------------------
    # private interface
    # ------------------------------------------------------------------------

    def _resolve_column_index(self, index):
        """ Resolves a column index into the correct model and adjusted
        index. Returns the target model and the corrected index.

        If ``index`` lies beyond the last column of the last model, the
        returned model is None.
        """

        real_index = index
        model = None
        for m in self.data:
            cols = m.get_column_count()
            if real_index < cols:
                model = m
                break
            real_index -= cols

        # Recorded only; see the note on _data_index.
        self._data_index[index] = (model, real_index)

        return model, real_index

    def _data_changed(self):
        """ Called when the data trait is changed.

        Since this is called when our underlying models change, the cached
        results of the column lookups and the row count are wrong and need
        to be invalidated.
        """

        self._data_index.clear()
        self._row_count = None

    def _data_items_changed(self):
        """ Called when the members of the data trait have changed.

        Since this is called when our underlying models change, the cached
        results of the column lookups and the row count are wrong and need
        to be invalidated.
        """
        self._data_index.clear()
        self._row_count = None
Пример #12
0
class PyContext(Context, Referenceable):
    """ A naming context for a Python namespace.

    The namespace is either a dictionary (used directly) or any object with
    a ``__dict__`` attribute, in which case that object is kept in ``obj``
    and its ``__dict__`` becomes the namespace.
    """

    #### 'Context' interface ##################################################

    # The naming environment in effect for this context.
    environment = Dict(ENVIRONMENT)

    #### 'PyContext' interface ################################################

    # The Python namespace that we represent.
    namespace = Any

    # If the namespace is actually a Python object that has a '__dict__'
    # attribute, then this will be that object (and the namespace will be
    # the object's '__dict__').
    obj = Any

    #### 'Referenceable' interface ############################################

    # The object's reference suitable for binding in a naming context.
    reference = Property(Instance(Reference))

    ###########################################################################
    # 'object' interface.
    ###########################################################################

    def __init__(self, **traits):
        """ Creates a new context.

        Raises
        ------
        ValueError
            If ``namespace`` is neither a dict (or dict subclass) nor an
            object with a ``__dict__`` attribute.
        """

        # Base class constructor.
        super(PyContext, self).__init__(**traits)

        # isinstance (not an exact type check) so dict subclasses such as
        # OrderedDict are used as-is; otherwise their (empty) instance
        # __dict__ would silently replace the real namespace.
        if not isinstance(self.namespace, dict):
            if hasattr(self.namespace, '__dict__'):
                self.obj = self.namespace
                self.namespace = self.namespace.__dict__

            else:
                raise ValueError('Need a dictionary or a __dict__ attribute')

    ###########################################################################
    # 'Referenceable' interface.
    ###########################################################################

    #### Properties ###########################################################

    def _get_reference(self):
        """ Returns a reference to this object suitable for binding. """

        return Reference(
            class_name=self.__class__.__name__,
            addresses=[Address(type='py_context', content=self.namespace)])

    ###########################################################################
    # Protected 'Context' interface.
    ###########################################################################

    def _is_bound(self, name):
        """ Is a name bound in this context? """

        return name in self.namespace

    def _lookup(self, name):
        """ Looks up a name in this context. """

        obj = self.namespace[name]

        return naming_manager.get_object_instance(obj, name, self)

    def _bind(self, name, obj):
        """ Binds a name to an object in this context. """

        state = naming_manager.get_state_to_bind(obj, name, self)
        self.namespace[name] = state

        # Trait event notification.
        # An "added" event is fired by the bind method of the base class
        # (which calls this one), so we don't need to fire "changed" here
        # (which would be the wrong thing anyway) -- LGV

    def _rebind(self, name, obj):
        """ Rebinds a name to an object in this context. """

        self._bind(name, obj)

    def _unbind(self, name):
        """ Unbinds a name from this context. """

        del self.namespace[name]

        # Trait event notification.
        self.trait_property_changed('context_changed', None, None)

    def _rename(self, old_name, new_name):
        """ Renames an object in this context. """

        state = self.namespace[old_name]

        # Bind the new name.
        self.namespace[new_name] = state

        # Unbind the old one.
        del self.namespace[old_name]

        # Trait event notification.
        # NOTE(review): other mutators use trait_property_changed(...); this
        # assignment is kept as-is since listeners may see new=True here.
        self.context_changed = True

    def _create_subcontext(self, name):
        """ Creates a sub-context of this context. """

        sub = self._context_factory(name, {})
        self.namespace[name] = sub

        # Trait event notification.
        self.trait_property_changed('context_changed', None, None)

        return sub

    def _destroy_subcontext(self, name):
        """ Destroys a sub-context of this context. """

        del self.namespace[name]

        # Trait event notification.
        self.trait_property_changed('context_changed', None, None)

    def _list_bindings(self):
        """ Lists the bindings in this context. """

        # Snapshot the names first so lookups cannot disturb the iteration.
        return [
            Binding(name=name, obj=self._lookup(name), context=self)
            for name in list(self.namespace)
        ]

    def _list_names(self):
        """ Lists the names bound in this context. """

        return list(self.namespace.keys())

    ###########################################################################
    # Private interface.
    ###########################################################################

    def _context_factory(self, name, namespace):
        """ Create a sub-context of the same class over ``namespace``. """

        return self.__class__(namespace=namespace)
Пример #13
0
class ConfirmationDialog(MConfirmationDialog, Dialog):
    """ Qt implementation of a ConfirmationDialog, built on QMessageBox.

    See the IConfirmationDialog interface for the API documentation.
    """

    # 'IConfirmationDialog' interface -------------------------------------#

    cancel = Bool(False)

    default = Enum(NO, YES, CANCEL)

    image = Image()

    message = Str()

    informative = Str()

    detail = Str()

    no_label = Str()

    yes_label = Str()

    # Maps every button added to the message box to the YES/NO/CANCEL result
    # it stands for. The return value of QMessageBox.exec_() is neither a
    # correct nor a sufficient indication of the user's choice once buttons
    # with roles are involved (as of Qt 4.5.1, clicking a button with the
    # YesRole could still make exec_() return QDialog.DialogCode.Rejected),
    # so the clicked button is looked up here instead.
    _button_result_map = Dict()

    # ------------------------------------------------------------------------
    # Protected 'IDialog' interface.
    # ------------------------------------------------------------------------

    def _create_contents(self, parent):
        # QMessageBox is a ready-made (canned) dialog; nothing to build.
        pass

    # ------------------------------------------------------------------------
    # Protected 'IWidget' interface.
    # ------------------------------------------------------------------------

    def _create_control(self, parent):
        """ Create and configure the QMessageBox backing this dialog. """
        message_box = QtGui.QMessageBox
        dlg = message_box(parent)

        dlg.setWindowTitle(self.title)
        dlg.setText(self.message)
        dlg.setInformativeText(self.informative)
        dlg.setDetailedText(self.detail)

        # (-1, -1) means no explicit size / position was requested.
        if self.size != (-1, -1):
            dlg.resize(*self.size)
        if self.position != (-1, -1):
            dlg.move(*self.position)

        if self.image is not None:
            dlg.setIconPixmap(self.image.create_image())
        else:
            dlg.setIcon(message_box.Icon.Warning)

        # One entry per button, in display order:
        # (result, custom label, role used with a custom label,
        #  standard button used when no custom label is set).
        button_specs = [
            (YES, self.yes_label, message_box.ButtonRole.YesRole,
             message_box.StandardButton.Yes),
            (NO, self.no_label, message_box.ButtonRole.NoRole,
             message_box.StandardButton.No),
        ]
        if self.cancel:
            button_specs.append(
                (CANCEL, self.cancel_label,
                 message_box.ButtonRole.RejectRole,
                 message_box.StandardButton.Cancel))

        for result, label, role, standard_button in button_specs:
            if label:
                button = dlg.addButton(label, role)
            else:
                button = dlg.addButton(standard_button)
            self._button_result_map[button] = result
            if self.default == result:
                dlg.setDefaultButton(button)

        return dlg

    def _show_modal(self):
        """ Run the message box application-modally; return YES, NO or CANCEL. """
        self.control.setWindowModality(
            QtCore.Qt.WindowModality.ApplicationModal)
        exec_code = self.control.exec_()

        if self.control is None:
            # The dialog window was closed while running: that counts as
            # Cancel when Cancel is offered, otherwise as the default.
            return CANCEL if self.cancel else self.default

        clicked = self.control.clickedButton()
        if clicked in self._button_result_map:
            return self._button_result_map[clicked]
        return _RESULT_MAP[exec_code]
Пример #14
0
class TupleIndexManager(AbstractIndexManager):
    """ Index manager whose indices are canonical ``(parent, row)`` tuples.

    Equal tuples are interned in ``_cache`` so that each distinct index is
    represented by a single canonical tuple object.  The builtin ``id()`` of
    that canonical object is used as the integer id, and ``_id_cache`` maps
    ids back to the canonical tuple.  The Root index always has id 0.
    Because ``_cache`` holds a reference to every canonical tuple, those
    objects stay alive and their ids remain valid until the caches are reset.
    """

    #: A dictionary that maps tuples to the canonical version of the tuple.
    _cache = Dict(Tuple, Tuple, {Root: Root}, can_reset=True)

    #: A dictionary that maps ids to the canonical version of the tuple.
    _id_cache = Dict(Int, Tuple, {0: Root}, can_reset=True)

    def create_index(self, parent, row):
        """ Given a parent index and a row number, create an index.

        Parameters
        ----------
        parent : index object
            The parent index object.
        row : non-negative int
            The position of the resulting index in the parent's children.

        Returns
        -------
        index : index object
            The resulting opaque index object.

        Raises
        ------
        IndexError
            Negative row values raise an IndexError exception.
        """
        if row < 0:
            raise IndexError("Row must be non-negative.  Got {}".format(row))

        index = (parent, row)
        # Intern the tuple so equal indices share one object (and one id).
        canonical_index = self._cache.setdefault(index, index)
        self._id_cache[self.id(canonical_index)] = canonical_index
        return canonical_index

    def get_parent_and_row(self, index):
        """ Given an index object, return the parent index and row.

        Parameters
        ----------
        index : index object
            The opaque index object.

        Returns
        -------
        parent : index object
            The parent index object.
        row : int
            The position of the resulting index in the parent's children.

        Raises
        ------
        IndexError
            If the Root object is passed as the index, this method will
            raise an IndexError, as it has no parent.
        """
        if index == Root:
            raise IndexError("Root index has no parent.")
        # Indices are (parent, row) tuples, so the index itself unpacks.
        return index

    def from_id(self, id):
        """ Given an integer id, return the corresponding index.

        Parameters
        ----------
        id : int
            An integer object id value.

        Returns
        -------
        index : index object
            The persistent index object associated with this id.
        """
        return self._id_cache[id]

    def id(self, index):
        """ Given an index, return the corresponding id.

        The canonical tuple is registered in the id cache, so the returned
        value can always be passed back to :meth:`from_id`, even for tuples
        that were not produced by :meth:`create_index`.

        Parameters
        ----------
        index : index object
            The persistent index object.

        Returns
        -------
        id : int
            The associated integer object id value.
        """
        if index == Root:
            return 0
        canonical_index = self._cache.setdefault(index, index)
        # Inside the method body ``id`` is the builtin, not this method.
        index_id = id(canonical_index)
        # Keep from_id() the inverse of id() for every canonical tuple.
        self._id_cache[index_id] = canonical_index
        return index_id
Пример #15
0
class BeadCalibrationOp(HasStrictTraits):
    """
    Calibrate arbitrary channels to molecules-of-fluorophore using fluorescent
    beads (eg, the Spherotech RCP-30-5A rainbow beads.)

    To use, set the `beads_file` property to an FCS file containing the beads'
    events; specify which beads you ran by setting the `beads` attribute
    to one of the values of BeadCalibrationOp.BEADS; and set the
    `units` dict to which channels you want calibrated and in which units.
    Then, call `estimate()` and check the peak-finding with
    `default_view().plot()`.  If the peak-finding is wacky, try adjusting
    `bead_peak_quantile` and `bead_brightness_threshold`.  When the peaks are
    successfully identified, call apply() on your experimental data set.

    If you can't make the peak finding work, please submit a bug report!

    This procedure works best when the beads file is very clean data.  It does
    not do its own gating (maybe a future addition?)  In the meantime,
    I recommend gating the *acquisition* on the FSC/SSC channels in order
    to get rid of debris, cells, and other noise.

    Finally, because you can't have a negative number of fluorescent molecules
    (MEFLs, etc) (as well as for math reasons), this module filters out
    non-positive values.

    Attributes
    ----------
    name : Str
        The operation name (for UI representation.)

    units : Dict(Str, Str)
        A dictionary specifying the channels you want calibrated (keys) and
        the units you want them calibrated in (values).  The units must be
        keys of the `beads` attribute.

    beads_file : File
        A file containing the FCS events from the beads.  Must be set to use
        `estimate()`.  This isn't persisted by `pickle()`.

    beads : Dict(Str, List(Float))
        The beads' characteristics.  Keys are calibrated units (ie, MEFL or
        MEAP) and values are ordered lists of known fluorophore levels.  Common
        values for this dict are included in BeadCalibrationOp.BEADS.
        Must be set to use `estimate()`.

    bead_peak_quantile : Int
        The quantile threshold used to choose bead peaks.  Default == 80.
        Must be set to use `estimate()`.

    bead_brightness_threshold : Float
        How bright must a bead peak be to be considered?  Default == 100.
        Must be set to use `estimate()`.

    bead_brightness_cutoff : Float
        If a bead peak is above this, then don't consider it.  Takes care of
        clipping saturated detection.  Defaults to 70% of the detector range.

    Notes
    -----
    The peak finding is rather sophisticated.

    For each channel, the bead data is binned into a histogram with 256
    log2-spaced bin edges running from 2 up to the channel's data range.  The
    first and last bins are zeroed to drop off-scale events, and then the
    histogram is smoothed with a Savitzky-Golay filter (with a window length
    of 5 and a polynomial order of 1).

    Next, a wavelet-based peak-finding algorithm is used: it convolves the
    smoothed histogram with a series of wavelets and looks for relative
    maxima at various length-scales.  The parameters of the smoothing and
    peak-finding algorithms were arrived at empirically, using beads
    collected at a wide range of PMT voltages.

    Finally, the peaks are filtered by height (the smoothed histogram bin is
    above the `bead_peak_quantile` percentile of all bins) and intensity
    (brighter than `bead_brightness_threshold` and dimmer than
    `bead_brightness_cutoff`).

    How to convert from a series of peaks to mean equivalent fluorochrome?
    If there's one peak, we assume that it's the brightest peak.  If there
    are two peaks, we assume they're the brightest two.  If there are n >= 3
    peaks, we check all the contiguous n-subsets of the bead intensities
    and find the one whose linear regression (in log space!) has the smallest
    sum of squared residuals.

    There's a slight subtlety in the fact that we're performing the linear
    regression in log-space: if the relationship in log10-space is
    Y = aX + b, with X = log10(x) and Y = log10(y), then the same
    relationship in linear space is y = (10**b) * (x ** a).

    One more thing.  Because the beads are (log) evenly spaced across all
    the channels, we can directly compute the fluorophore equivalent in channels
    where we wouldn't usually measure that fluorophore: for example, you can
    compute MEFL (mean equivalent fluorescein) in the PE-Texas Red channel,
    because the bead peak pattern is the same in the PE-Texas Red channel
    as it would be in the FITC channel.

    Examples
    --------
    >>> bead_op = flow.BeadCalibrationOp()
    >>> bead_op.beads = flow.BeadCalibrationOp.BEADS["Spherotech RCP-30-5A Lot AA01-AA04, AB01, AB02, AC01, GAA01-R"]
    >>> bead_op.units = {"Pacific Blue-A" : "MEFL",
    ...                  "FITC-A" : "MEFL",
    ...                  "PE-Tx-Red-YG-A" : "MEFL"}
    >>>
    >>> bead_op.beads_file = "beads.fcs"
    >>> bead_op.estimate(ex3)
    >>>
    >>> bead_op.default_view().plot(ex3)
    >>> # check the plot!
    >>>
    >>> ex4 = bead_op.apply(ex3)
    """

    # traits
    id = Constant('edu.mit.synbio.cytoflow.operations.beads_calibrate')
    friendly_id = Constant("Bead Calibration")

    name = Constant("Bead Calibration")
    units = Dict(Str, Str)

    beads_file = File(exists = True)
    bead_peak_quantile = Int(80)

    bead_brightness_threshold = Float(100)
    bead_brightness_cutoff = Float(Undefined)
    # TODO - bead_brightness_threshold should probably be different depending
    # on the data range of the input.

    beads = Dict(Str, List(Float))

    # per-channel results of estimate(); not persisted
    _calibration_functions = Dict(Str, Python, transient = True)
    _peaks = Dict(Str, Python, transient = True)
    _mefs = Dict(Str, Python, transient = True)

    def estimate(self, experiment, subset = None):
        """
        Estimate the calibration coefficients from the beads file.

        For each channel in `units`, bead peaks are located in the events
        imported from `beads_file` (see `_find_peaks`) and matched against
        the known fluorophore levels in `beads`.  The resulting calibration
        function, the peaks used and their matching levels are stored for
        `apply()` and the diagnostic view.

        Parameters
        ----------
        experiment : Experiment
            The experiment the calibration will be applied to.  Its channel
            list, per-channel ``range`` metadata and ``name_metadata`` are
            used.
        subset : Str, optional
            Accepted for interface compatibility; currently ignored.

        Raises
        ------
        CytoflowOpError
            If the inputs are missing or inconsistent, or if no usable set of
            peaks is found for a channel.
        """
        if not experiment:
            raise util.CytoflowOpError("No experiment specified")

        if not self.beads_file:
            raise util.CytoflowOpError("No beads file specified")

        if not set(self.units.keys()) <= set(experiment.channels):
            raise util.CytoflowOpError("Specified channels that weren't found in "
                                  "the experiment.")

        if not set(self.units.values()) <= set(self.beads.keys()):
            raise util.CytoflowOpError("Units don't match beads.")

        # make a little Experiment
        check_tube(self.beads_file, experiment)
        beads_exp = ImportOp(tubes = [Tube(file = self.beads_file)],
                             name_metadata = experiment.metadata['name_metadata']).apply()

        for channel in self.units.keys():
            peaks = self._find_peaks(beads_exp.data[channel],
                                     experiment.metadata[channel]['range'])

            mef_unit = self.units[channel]

            if mef_unit not in self.beads:
                raise util.CytoflowOpError("Invalid unit {0} specified for channel {1}".format(mef_unit, channel))

            # "mean equivalent fluorochrome" levels, dimmest first
            mef = self.beads[mef_unit]

            if len(peaks) == 0:
                raise util.CytoflowOpError("Didn't find any peaks; check the diagnostic plot")
            elif len(peaks) > len(mef):
                # compare against the number of bead *levels* for this unit;
                # with more peaks than levels no contiguous subset can match.
                raise util.CytoflowOpError("Found too many peaks; check the diagnostic plot")
            elif len(peaks) == 1:
                # if we only have one peak, assume it's the brightest peak
                a = mef[-1] / peaks[0]
                self._peaks[channel] = peaks
                self._mefs[channel] = [mef[-1]]
                self._calibration_functions[channel] = lambda x, a=a: a * x
            elif len(peaks) == 2:
                # if we have only two peaks, assume they're the brightest two.
                # peaks are ascending, so list the levels in the same order to
                # keep _peaks[i] paired with _mefs[i].
                self._peaks[channel] = peaks
                self._mefs[channel] = [mef[-2], mef[-1]]
                a = (mef[-1] - mef[-2]) / (peaks[1] - peaks[0])
                self._calibration_functions[channel] = lambda x, a=a: a * x
            else:
                # if there are n > 2 peaks, check all the contiguous n-subsets
                # of mef for the one whose linear regression with the peaks
                # has the smallest sum of squared residuals.

                # do it in log10 space because otherwise the brightest peaks
                # have an outsized influence.

                best_resid = np.inf
                for start in range(len(mef) - len(peaks) + 1):
                    mef_subset = mef[start:start + len(peaks)]

                    # linear regression of the peak locations against mef subset
                    lr = np.polyfit(np.log10(peaks),
                                    np.log10(mef_subset),
                                    deg = 1,
                                    full = True)

                    # with full=True, lr[1] holds the sum of squared residuals
                    resid = lr[1][0]
                    if resid < best_resid:
                        best_lr = lr[0]
                        best_resid = resid
                        self._peaks[channel] = peaks
                        self._mefs[channel] = mef_subset

                # remember, these (linear) coefficients came from logspace, so
                # if the relationship in log10 space is Y = aX + b, then in
                # linear space the relationship is x = 10**X, y = 10**Y,
                # and y = (10**b) * x ^ a

                # also remember that the result of np.polyfit is a list of
                # coefficients with the highest power first!  so if we
                # solve y=ax + b, coeff #0 is a and coeff #1 is b

                a = best_lr[0]
                b = 10 ** best_lr[1]
                self._calibration_functions[channel] = \
                    lambda x, a=a, b=b: b * np.power(x, a)

    def _find_peaks(self, data, data_range):
        """
        Locate bead peaks in one channel's bead events.

        Parameters
        ----------
        data : array-like
            The bead events for a single channel.
        data_range : float
            The channel's ``range`` metadata; sets the top of the histogram
            and the default brightness cutoff.

        Returns
        -------
        list
            The left bin edges of the histogram bins identified as peaks.
            These are presumably ascending, since `find_peaks_cwt` returns
            sorted indices; the callers rely on that ordering.
        """
        # TODO - this assumes the data is on a linear scale.  check it!
        if self.bead_brightness_cutoff is Undefined:
            cutoff = 0.7 * data_range
        else:
            cutoff = self.bead_brightness_cutoff

        # bin the data on a log scale: 256 edges from 2**1 to data_range
        hist_bins = np.logspace(1, math.log(data_range, 2), num = 256, base = 2)
        hist = np.histogram(data, bins = hist_bins)

        # mask off-scale values
        hist[0][0] = 0
        hist[0][-1] = 0

        # smooth it with a Savitzky-Golay filter
        hist_smooth = scipy.signal.savgol_filter(hist[0], 5, 1)

        # find peaks
        peak_bins = scipy.signal.find_peaks_cwt(hist_smooth,
                                                widths = np.arange(3, 20),
                                                max_distances = np.arange(3, 20) / 2)

        # filter by height and intensity
        peak_threshold = np.percentile(hist_smooth, self.bead_peak_quantile)
        peak_bins_filtered = \
            [x for x in peak_bins if hist_smooth[x] > peak_threshold
             and hist[1][x] > self.bead_brightness_threshold
             and hist[1][x] < cutoff]

        return [hist_bins[x] for x in peak_bins_filtered]

    def apply(self, experiment):
        """Applies the bead calibration to an experiment.

        Events with a non-positive value in any calibrated channel are
        dropped, then each channel in `units` is transformed with the
        calibration function computed by `estimate()`.

        Parameters
        ----------
        experiment : Experiment
            the experiment to which this op is applied

        Returns
        -------
        Experiment
            a new experiment calibrated in physical units.

        Raises
        ------
        CytoflowOpError
            If no experiment is given, `units` is empty, `estimate()` has not
            been called (or was called with different units), or a
            calibrated channel is missing from the experiment.
        """
        if not experiment:
            raise util.CytoflowOpError("No experiment specified")

        channels = self.units.keys()

        if not self.units:
            raise util.CytoflowOpError("No channels to calibrate.")

        if not self._calibration_functions:
            raise util.CytoflowOpError("Calibration not found. "
                                  "Did you forget to call estimate()?")

        if not set(channels) <= set(experiment.channels):
            raise util.CytoflowOpError("Module units don't match experiment channels")

        if set(channels) != set(self._calibration_functions.keys()):
            raise util.CytoflowOpError("Calibration doesn't match units. "
                                  "Did you forget to call estimate()?")

        # two things.  first, you can't raise a negative value to a non-integer
        # power.  second, negative physical units don't make sense -- how can
        # you have the equivalent of -5 molecules of fluoresceine?  so,
        # we filter out non-positive values here.

        new_experiment = experiment.clone()

        for channel in channels:
            new_experiment.data = \
                new_experiment.data[new_experiment.data[channel] > 0]

        new_experiment.data.reset_index(drop = True, inplace = True)

        for channel in channels:
            calibration_fn = self._calibration_functions[channel]

            new_experiment[channel] = calibration_fn(new_experiment[channel])
            new_experiment.metadata[channel]['bead_calibration_fn'] = calibration_fn
            new_experiment.metadata[channel]['units'] = self.units[channel]
            if 'range' in experiment.metadata[channel]:
                new_experiment.metadata[channel]['range'] = calibration_fn(experiment.metadata[channel]['range'])

        new_experiment.history.append(self.clone_traits(transient = lambda t: True))
        return new_experiment

    def default_view(self, **kwargs):
        """
        Returns a diagnostic plot to see if the bead peak-finding is working.

        Returns
        -------
            IView : An IView, call plot() to see the diagnostic plots
        """

        return BeadCalibrationDiagnostic(op = self, **kwargs)

    BEADS = {
             # from http://www.spherotech.com/RCP-30-5a%20%20rev%20H%20ML%20071712.xls
             "Spherotech RCP-30-5A Lot AG01, AF02, AD04 and AAE01" :
                { "MECSB" : [216, 464, 1232, 2940, 7669, 19812, 35474],
                  "MEBFP" : [861, 1997, 5776, 15233, 45389, 152562, 396759],
                  "MEFL" :  [792, 2079, 6588, 16471, 47497, 137049, 271647],
                  "MEPE" :  [531, 1504, 4819, 12506, 36159, 109588, 250892],
                  "MEPTR" : [233, 669, 2179, 5929, 18219, 63944, 188785],
                  "MECY" : [1614, 4035, 12025, 31896, 95682, 353225, 1077421],
                  "MEPCY7" : [14916, 42336, 153840, 494263],
                  "MEAP" :  [373, 1079, 3633, 9896, 28189, 79831, 151008],
                  "MEAPCY7" : [2864, 7644, 19081, 37258]},
             # from http://www.spherotech.com/RCP-30-5a%20%20rev%20G.2.xls
             "Spherotech RCP-30-5A Lot AA01-AA04, AB01, AB02, AC01, GAA01-R":
                { "MECSB" : [179, 400, 993, 3203, 6083, 17777, 36331],
                  "MEBFP" : [700, 1705, 4262, 17546, 35669, 133387, 412089],
                  "MEFL" :  [692, 2192, 6028, 17493, 35674, 126907, 290983],
                  "MEPE" :  [505, 1777, 4974, 13118, 26757, 94930, 250470],
                  "MEPTR" : [207, 750, 2198, 6063, 12887, 51686, 170219],
                  "MECY" :  [1437, 4693, 12901, 36837, 76621, 261671, 1069858],
                  "MEPCY7" : [32907, 107787, 503797],
                  "MEAP" :  [587, 2433, 6720, 17962, 30866, 51704, 146080],
                  "MEAPCY7" : [718, 1920, 5133, 9324, 14210, 26735]}}
Пример #16
0
class Dimensions(HasTraits):
    """The dimensions of a physical quantity.

    This is essentially a thin wrapper around a dictionary which we perform
    certain operations on.  Zero powers are dropped on construction, so an
    empty dictionary means "dimensionless".

    Example
    -------
    >>> m = Dimensions({'mass': 1.0})
    >>> a = Dimensions({'length': 1.0, 'time': -2.0})
    >>> f = Dimensions({'length': 1.0, 'mass': 1.0, 'time': -2.0})
    >>> f == m*a
    True
    >>> f.expansion
    'length*mass*time**-2.0'
    """

    # a dictionary holding dimension names and quantities
    # this should be frozen if you want to hash - don't change it
    dimension_dict = Dict(Str, Float)

    # the quantity type as an expression in powers of base dimensions
    expansion = Property(String, depends_on='dimension_dict')

    def __init__(self, dimension_dict, **kwargs):
        # drop zero powers so that e.g. {'mass': 0.0} compares equal to {}
        dimension_dict = {k: v for k, v in dimension_dict.items() if v}
        # name the class explicitly: super(self.__class__, self) recurses
        # forever as soon as Dimensions is subclassed
        super(Dimensions, self).__init__(
            dimension_dict=dimension_dict,
            **kwargs)

    @classmethod
    def from_expansion(cls, expansion):
        """Create a Dimension class instance from an expansion string

        This is a fairly simplistic parser - no parens, division, etc.
        Terms are separated by ``*``; a term may carry a power written as
        ``name**power``, and a term without a power counts as power 1.
        Repeated dimensions are summed.  The string ``"dimensionless"``
        (what :attr:`expansion` produces for an empty dictionary) yields a
        dimensionless instance.

        Parameters
        ----------
        expansion : string
            an expansion of the dimensions (eg. mass*length**-3.0)

        Raises
        ------
        InvalidExpansionError
            If the string cannot be parsed.
        """
        if expansion == "dimensionless":
            return cls({})

        terms = expansion.split("*")
        dimension_dict = {}
        try:
            while terms:
                dim = terms.pop(0)
                if not dim:
                    # stray "*" (e.g. "*mass" or "**2") - no dimension name
                    raise ValueError("empty dimension name")
                if terms and terms[0] == "":
                    # "dim**power" splits on "*" into [dim, "", power]
                    terms.pop(0)
                    power = float(terms.pop(0))
                else:
                    power = 1.0
                dimension_dict[dim] = dimension_dict.get(dim, 0) + power
        except (IndexError, ValueError):
            raise InvalidExpansionError(expansion)
        return cls(dimension_dict)

    @cached_property
    def _get_expansion(self):
        if self.dimension_dict:
            return format_expansion(self.dimension_dict)
        else:
            return "dimensionless"

    def __repr__(self):
        return "Dimensions(%s)" % repr(self.dimension_dict)

    def __str__(self):
        return self.expansion

    def __eq__(self, other):
        return isinstance(other, self.__class__) \
            and self.dimension_dict == other.dimension_dict

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__
        return not self == other

    def __hash__(self):
        # frozenset makes the hash independent of dict insertion order, so
        # instances that compare equal always hash equal
        return hash(frozenset(self.dimension_dict.items()))

    def __mul__(self, other):
        if isinstance(other, Dimensions):
            return Dimensions(dict_add(self.dimension_dict,
                                       other.dimension_dict))
        else:
            raise NotImplementedError

    def __div__(self, other):
        # Python 2 division operator; delegate to true division
        return type(self).__truediv__(self, other)

    def __truediv__(self, other):
        if isinstance(other, Dimensions):
            return Dimensions(dict_sub(self.dimension_dict,
                                       other.dimension_dict))
        else:
            raise NotImplementedError

    def __pow__(self, other):
        if isinstance(other, (float,) + six.integer_types):
            return Dimensions(dict_mul(self.dimension_dict, other))
        else:
            raise NotImplementedError
Пример #17
0
class BeadCalibrationPluginOp(PluginOpMixin, BeadCalibrationOp):
    """GUI-plugin wrapper around BeadCalibrationOp.

    The GUI edits ``beads_name`` and ``units_list``.  Before estimating or
    applying, these are validated and copied into the base operation's
    ``beads`` and ``units`` traits (both transient here).
    """
    handler_factory = Callable(BeadCalibrationHandler)

    beads_name = Str(estimate=True)
    beads = Dict(Str, List(Float), transient=True)

    beads_file = File(filter=["*.fcs"], estimate=True)
    units_list = List(_Unit, estimate=True)
    units = Dict(Str, Str, transient=True)

    bead_peak_quantile = CInt(80, estimate=True)
    bead_brightness_threshold = CFloat(100.0, estimate=True)
    bead_brightness_cutoff = util.CFloatOrNone(None, estimate=True)

    @on_trait_change('units_list_items,units_list.+', post_init=True)
    def _controls_changed(self, obj, name, old, new):
        # any edit to the unit list (rows added/removed or a row's fields
        # changed) is reported as an ESTIMATE-level change
        self.changed = (Changed.ESTIMATE, ('units_list', self.units_list))

    def default_view(self, **kwargs):
        return BeadCalibrationPluginView(op=self, **kwargs)

    def _sync_units_and_beads(self):
        """Validate the GUI selections and copy them into ``units``/``beads``.

        Raises
        ------
        CytoflowOpError
            If no beads are selected, or if a channel appears more than once
            in ``units_list`` (the first such channel, in list order, is
            reported).
        """
        from collections import Counter

        if not self.beads_name:
            raise util.CytoflowOpError(
                "Specify which beads to calibrate with.")

        # count once instead of comparing every pair of rows
        channel_counts = Counter(unit.channel for unit in self.units_list)
        for unit in self.units_list:
            if channel_counts[unit.channel] > 1:
                raise util.CytoflowOpError(
                    "Channel {0} is included more than once".format(
                        unit.channel))

        self.units = {unit.channel: unit.unit for unit in self.units_list}
        self.beads = self.BEADS[self.beads_name]

    def apply(self, experiment):
        """Validate the GUI state, then apply the base calibration."""
        self._sync_units_and_beads()
        return BeadCalibrationOp.apply(self, experiment)

    def estimate(self, experiment):
        """Validate the GUI state, then run the base estimate.

        An ESTIMATE_RESULT change is signalled whether or not the estimate
        succeeds; validation errors are raised before that point.
        """
        self._sync_units_and_beads()
        try:
            BeadCalibrationOp.estimate(self, experiment)
        finally:
            self.changed = (Changed.ESTIMATE_RESULT, self)

    def should_clear_estimate(self, changed, payload):
        # only ESTIMATE-level changes invalidate the calibration
        if changed == Changed.ESTIMATE:
            return True

        return False

    def clear_estimate(self):
        """Forget the current calibration and notify listeners."""
        self._calibration_functions.clear()
        self._peaks.clear()
        self._mefs.clear()
        # NOTE(review): `_histograms` is not declared on the BeadCalibrationOp
        # shown in this file; presumably the base-class version this plugin
        # targets defines it - confirm.
        self._histograms.clear()
        self.changed = (Changed.ESTIMATE_RESULT, self)

    def get_notebook_code(self, idx):
        """Return Jupyter-notebook source that reproduces this operation.

        Parameters
        ----------
        idx : int
            Index of this operation in the workflow; the generated code
            reads ``ex_{idx-1}`` and writes ``op_{idx}`` / ``ex_{idx}``.
        """
        op = BeadCalibrationOp()
        op.copy_traits(self, op.copyable_trait_names())

        for unit in self.units_list:
            op.units[unit.channel] = unit.unit

        op.beads = self.BEADS[self.beads_name]

        return dedent("""
        # Beads: {beads}
        op_{idx} = {repr}
        
        op_{idx}.estimate(ex_{prev_idx})
        ex_{idx} = op_{idx}.apply(ex_{prev_idx})
        """.format(beads=self.beads_name,
                   repr=repr(op),
                   idx=idx,
                   prev_idx=idx - 1))
Пример #18
0
class Experiment(HasStrictTraits):
    """
    An Experiment manages all the data and metadata for a flow experiment.

    An :class:`Experiment` is the central data structure in :mod:`cytoflow`: it
    wraps a :class:`pandas.DataFrame` containing all the data from a flow
    experiment. Each row in the table is an event.  Each column is either a
    measurement from one of the detectors (or a "derived" measurement such as
    a transformed value or a ratio), or a piece of metadata associated with
    that event: which tube it came from, what the experimental conditions for
    that tube were, gate membership, etc.  The :class:`Experiment` object lets
    you:

      - Add additional metadata to define subpopulations
      - Get events that match a particular metadata signature.

    Additionally, the :class:`Experiment` object manages channel- and
    experiment-level metadata in the :attr:`metadata` attribute, which is a
    dictionary.  This allows the rest of the :mod:`cytoflow` package to track
    and enforce other constraints that are important in doing quantitative
    flow cytometry: for example, every tube must be collected with the same
    channel parameters (such as PMT voltage.)

    .. note::

        :class:`Experiment` is not responsible for enforcing the constraints;
        :class:`.ImportOp` and the other modules are.

    Attributes
    ----------

    data : pandas.DataFrame
        All the events and metadata represented by this experiment.  Each event
        is a row; each column is either a measured channel (eg. a fluorescence
        measurement), a derived channel (eg. the ratio between two channels),
        or a piece of metadata.  Metadata can be either experimental conditions
        (eg. induction level, timepoint) or added by operations (eg. gate
        membership).

    metadata : Dict(Str : Dict(Str : Any))
        Each column in :attr:`data` has an entry in :attr:`metadata` whose key
        is the column name and whose value is a dict of column-specific
        metadata.  Metadata is added by operations, and is occasionally useful
        if modules are expected to work together.  See individual operations'
        documentation for a list of the metadata that operation adds.  The only
        "required" metadata is ``type``, which can be ``channel`` (if the
        column is a measured channel, or derived from one) or ``condition``
        (if the column is an experimental condition, gate membership, etc.)

        .. warning::

            There may also be experiment-wide entries in :attr:`metadata` that
            are *not* columns in :attr:`data`!

    history : List(IOperation)
        The :class:`.IOperation` operations that have been applied to the raw
        data to result in this :class:`Experiment`.

    statistics : Dict((Str, Str) : pandas.Series)
        The statistics and parameters computed by models that were fit to the
        data.  The key is an ``(Str, Str)`` tuple, where the first ``Str`` is
        the name of the operation that supplied the statistic, and the second
        ``Str`` is the name of the statistic. The value is a multi-indexed
        :class:`pandas.Series`: each level of the index is a facet, and each
        combination of indices is a subset for which the statistic was computed.
        The values of the series, of course, are the values of the computed
        parameters or statistics for each subset.

    channels : List(String)
        The channels that this experiment tracks (read-only).

    conditions : Dict(String : pandas.Series)
        The experimental conditions and analysis groups (gate membership, etc)
        that this experiment tracks.  The key is the name of the condition, and
        the value is a sorted :class:`pandas.Series` with that condition's
        distinct values.

    Notes
    -----

    The OOP programmer in me desperately wanted to subclass
    :class:`pandas.DataFrame`, add some flow-specific stuff, and move on with
    my life.  (I may still, with something like
    https://github.com/dalejung/pandas-composition).  A few things get in the
    way of directly subclassing :class:`pandas.DataFrame`:

     - First, to enable some of the delicious syntactic sugar for accessing
       its contents, :class:`pandas.DataFrame` redefines
       :meth:`__getattr__` and :meth:`__setattr__`, and making it
       recognize (and maintain across copies) additional attributes is an
       unsupported (non-public) API feature and introduces other
       subclassing weirdness.

     - Second, many of the operations (like appending!) don't happen in-place;
       they return copies instead.  It's cleaner to simply manage that copying
       ourselves instead of making the client deal with it.  We can pretend
       to operate on the data in-place.

    To maintain the ease of use, we'll override :meth:`__getitem__` and pass it
    to the wrapped :class:`pandas.DataFrame`.  We'll do the same with some of
    the more useful :class:`~pandas.DataFrame` API pieces (like :meth:`query`);
    and of course, you can just get the data frame itself with
    :attr:`Experiment.data`.

    Examples
    --------
    >>> import cytoflow as flow
    >>> tube1 = flow.Tube(file = 'cytoflow/tests/data/Plate01/RFP_Well_A3.fcs',
    ...                   conditions = {"Dox" : 10.0})
    >>> tube2 = flow.Tube(file='cytoflow/tests/data/Plate01/CFP_Well_A4.fcs',
    ...                   conditions = {"Dox" : 1.0})
    >>>
    >>> import_op = flow.ImportOp(conditions = {"Dox" : "float"},
    ...                           tubes = [tube1, tube2])
    >>>
    >>> ex = import_op.apply()
    >>> ex.data.shape
    (20000, 17)
    >>> ex.data.groupby(['Dox']).size()
    Dox
    1      10000
    10     10000
    dtype: int64

    """

    # a DataFrame doesn't play nice with copy.copy() (used if, say, you copy
    # an Experiment with HasTraits.clone_traits()) -- instead, copy
    # a reference when clone_traits() is called; clone() then replaces it
    # with pandas.DataFrame.copy(deep = True)
    data = Instance(pd.DataFrame, args=(), copy="ref")

    # potentially mutable.  deep copy required
    metadata = Dict(Str, Any, copy="deep")

    # statistics.  mutable, deep copy required
    statistics = Dict(Tuple(Str, Str), pd.Series, copy="deep")

    history = List(Any, copy="shallow")

    channels = Property(List)
    conditions = Property(Dict)

    def __getitem__(self, key):
        """Override __getitem__ so we can reference columns like ex["column"]"""
        return self.data.__getitem__(key)

    def __setitem__(self, key, value):
        """
        Override __setitem__ so we can assign columns like ex["column"] = ...

        An existing column with the same name is dropped first, so the new
        values replace it outright (and the column moves to the end of
        :attr:`data`).
        """
        if key in self.data:
            self.data.drop(key, axis='columns', inplace=True)
        return self.data.__setitem__(key, value)

    def __len__(self):
        """Return the length of the underlying pandas.DataFrame"""
        return len(self.data)

    def _get_channels(self):
        """Getter for the `channels` property"""
        return sorted(
            [x for x in self.data if self.metadata[x]['type'] == "channel"])

    def _get_conditions(self):
        """Getter for the `conditions` property"""
        return {
            x: pd.Series(self.data[x].unique().copy()).sort_values()
            for x in self.data if self.metadata[x]['type'] == "condition"
        }

    def subset(self, conditions, values):
        """
        Returns a subset of this experiment including only the events where
        each condition in ``conditions`` equals the corresponding value in
        ``values``.


        Parameters
        ----------
        conditions : Str or Tuple(Str)
            A condition or list of conditions

        values : Any or Tuple(Any)
            The value(s) of the condition(s)

        Returns
        -------
        Experiment
            A new :class:`Experiment` containing only the events specified in
            ``conditions`` and ``values``.

        Raises
        ------
        :class:`.CytoflowError`
            If a name is not a condition, or a value is not one of that
            condition's values.

        """

        # self.conditions recomputes every condition's unique values, so
        # evaluate it once rather than once per checked pair.
        known_conditions = self.conditions

        if isinstance(conditions, str):
            pairs = [(conditions, values)]
        else:
            pairs = zip(conditions, values)

        for c, v in pairs:
            if c not in known_conditions:
                raise util.CytoflowError("{} is not a condition".format(c))
            if v not in list(known_conditions[c]):
                raise util.CytoflowError(
                    "{} is not a value of condition {}".format(v, c))

        g = self.data.groupby(conditions)

        ret = self.clone()
        ret.data = g.get_group(values)
        ret.data.reset_index(drop=True, inplace=True)

        return ret

    def query(self, expr, **kwargs):
        """
        Return an experiment whose data is a subset of this one where ``expr``
        evaluates to ``True``.

        This method "sanitizes" column names first, replacing characters that
        are not valid in a Python identifier with an underscore ``_``. So, the
        column name ``a column`` becomes ``a_column``, and can be queried with
        an ``a_column == True`` or such.

        Parameters
        ----------
        expr : string
            The expression to pass to :meth:`pandas.DataFrame.query`.  Must be
            a valid Python expression, something you could pass to :func:`eval`.

        **kwargs : dict
            Other named parameters to pass to :meth:`pandas.DataFrame.query`.

        Returns
        -------
        Experiment
            A new :class:`Experiment`, a clone of this one with the data
            returned by :meth:`pandas.DataFrame.query()`

        Raises
        ------
        :class:`.CytoflowError`
            If two column names sanitize to the same identifier, or if no
            events match ``expr``.
        """

        resolvers = {}
        # DataFrame.items() replaces iteritems(), which pandas 2.0 removed.
        for name, col in self.data.items():
            new_name = util.sanitize_identifier(name)
            if new_name in resolvers:
                raise util.CytoflowError(
                    "Tried to sanitize column name {0} to "
                    "{1} but it already existed in the "
                    "DataFrame.".format(name, new_name))
            else:
                resolvers[new_name] = col

        ret = self.clone()
        ret.data = self.data.query(expr, resolvers=({}, resolvers), **kwargs)
        ret.data.reset_index(drop=True, inplace=True)

        if len(ret.data) == 0:
            raise util.CytoflowError("No events matched {}".format(expr))

        return ret

    def clone(self):
        """
        Create a copy of this :class:`Experiment`.

        :attr:`metadata` and :attr:`statistics` are deep copies;
        :attr:`history` is a shallow copy; and :attr:`data` is a deep copy made
        with :meth:`pandas.DataFrame.copy`.

        Returns
        -------
        Experiment
            The new, independent copy.
        """

        new_exp = self.clone_traits()
        new_exp.data = self.data.copy(deep=True)

        return new_exp

    def add_condition(self, name, dtype, data=None):
        """
        Add a new column of per-event metadata to this :class:`Experiment`.

        .. note::
            :meth:`add_condition` operates **in place.**

        There are two places to call `add_condition`.

          - As you're setting up a new :class:`Experiment`, call
            :meth:`add_condition` with ``data`` set to ``None`` to specify the
            conditions the new events will have.
          - If you compute some new per-event metadata on an existing
            :class:`Experiment`, call :meth:`add_condition` to add it.

        Parameters
        ----------
        name : String
            The name of the new column in :attr:`data`.  Must be a valid Python
            identifier: must start with ``[A-Za-z_]`` and contain only the
            characters ``[A-Za-z0-9_]``.

        dtype : String
            The type of the new column in :attr:`data`.  Must be a string that
            :class:`pandas.Series` recognizes as a ``dtype``: common types are
            ``category``, ``float``, ``int``, and ``bool``.

        data : pandas.Series (default = None)
            The :class:`pandas.Series` to add to :attr:`data`.  Must be the same
            length as :attr:`data`, and it must be convertable to a
            :class:`pandas.Series` of type ``dtype``.  If ``None``, will add an
            empty column to the :class:`Experiment` ... but the
            :class:`Experiment` must be empty to do so!

        Raises
        ------
        :class:`.CytoflowError`
            If ``name`` is not a valid identifier or already a column, if the
            :class:`pandas.Series` passed in ``data`` isn't the same
            length as :attr:`data`, or isn't convertable to type ``dtype``.

        Examples
        --------
        >>> import cytoflow as flow
        >>> ex = flow.Experiment()
        >>> ex.add_condition("Time", "float")
        >>> ex.add_condition("Strain", "category")

        """

        if name != util.sanitize_identifier(name):
            raise util.CytoflowError(
                "Name '{}' is not a valid Python identifier".format(name))

        if name in self.data:
            raise util.CytoflowError(
                "Already a column named {0} in self.data".format(name))

        if data is None and len(self) > 0:
            raise util.CytoflowError(
                "If data is None, self.data must be empty!")

        if data is not None and len(self) != len(data):
            raise util.CytoflowError(
                "data must be the same length as self.data")

        try:
            if data is not None:
                self.data[name] = data.astype(dtype, copy=True)
            else:
                self.data[name] = pd.Series(dtype=dtype)

        except (ValueError, TypeError) as exc:
            raise util.CytoflowError(
                "Had trouble converting data to type {0}".format(
                    dtype)) from exc

        self.metadata[name] = {}
        self.metadata[name]['type'] = "condition"

    def add_channel(self, name, data=None):
        """
        Add a new column of per-event data (as opposed to metadata) to this
        :class:`Experiment`: ie, something that was measured per cell, or
        derived from per-cell measurements.

        .. note::

            :meth:`add_channel` operates *in place*.

        Parameters
        ----------
        name : String
            The name of the new column to be added to :attr:`data`.

        data : pandas.Series
            The :class:`pandas.Series` to add to :attr:`data`.  Must be the same
            length as :attr:`data`, and it must be convertable to a
            dtype of ``float64``.  If ``None``, will add an empty column to
            the :class:`Experiment` ... but the :class:`Experiment` must be
            empty to do so!

        Raises
        ------
        :exc:`.CytoflowError`
            If ``name`` is already a column, or if the :class:`pandas.Series`
            passed in ``data`` isn't the same length as :attr:`data`, or isn't
            convertable to a dtype ``float64``.

        Examples
        --------
        >>> ex.add_channel("FSC_over_2", ex.data["FSC-A"] / 2.0)

        """

        if name in self.data:
            raise util.CytoflowError(
                "Already a column named {0} in self.data".format(name))

        if data is None and len(self) > 0:
            raise util.CytoflowError(
                "If data is None, self.data must be empty!")

        if data is not None and len(self) != len(data):
            raise util.CytoflowError(
                "data must be the same length as self.data")

        try:
            if data is not None:
                self.data[name] = data.astype("float64", copy=True)
            else:
                self.data[name] = pd.Series(dtype="float64")

        except (ValueError, TypeError) as exc:
            raise util.CytoflowError(
                "Had trouble converting data to type \"float64\"") from exc

        self.metadata[name] = {}
        self.metadata[name]['type'] = "channel"

    def add_events(self, data, conditions):
        """
        Add new events to this :class:`Experiment`.

        Each new event in ``data`` is appended to :attr:`data`, and its
        per-event metadata columns will be set with the values specified in
        ``conditions``.  Thus, it is particularly useful for adding tubes of
        data to new experiments, before additional per-event metadata is added
        by gates, etc.

        .. note::

            *Every* column in :attr:`data` must be accounted for.  Each column
            of type ``channel`` must appear in ``data``; each column of
            metadata must have a key:value pair in ``conditions``.

        Parameters
        ----------
        data : pandas.DataFrame
            A single tube or well's worth of data. Must be a DataFrame with
            the same columns as :attr:`channels`

        conditions : Dict(Str, Any)
            A dictionary of the tube's metadata.  The keys must match
            :attr:`conditions`, and the values must be coercable to the
            relevant ``numpy`` dtype.

        Raises
        ------
        :exc:`.CytoflowError`
            :meth:`add_events` pukes if:

                - there are columns in ``data`` that aren't channels in the
                  experiment, or vice versa.
                - there are keys in ``conditions`` that aren't conditions in
                  the experiment, or vice versa.
                - there is metadata specified in ``conditions`` that can't be
                  converted to the corresponding metadata ``dtype``.

        Examples
        --------
        >>> import cytoflow as flow
        >>> import fcsparser
        >>> ex = flow.Experiment()
        >>> ex.add_condition("Time", "float")
        >>> ex.add_condition("Strain", "category")
        >>> tube1, _ = fcsparser.parse('CFP_Well_A4.fcs')
        >>> tube2, _ = fcsparser.parse('RFP_Well_A3.fcs')
        >>> ex.add_events(tube1, {"Time" : 1, "Strain" : "BL21"})
        >>> ex.add_events(tube2, {"Time" : 1, "Strain" : "Top10G"})

        """

        # make sure the new tube's channels match the rest of the
        # channels in the Experiment

        if len(self) > 0 and set(data.columns) != set(self.channels):
            raise util.CytoflowError("New events don't have the same channels")

        # self.conditions recomputes every condition's unique values, so
        # evaluate it once instead of once per key / per loop iteration.
        known_conditions = self.conditions

        # check that the conditions for this tube exist in the experiment
        # already -- the key sets must match exactly.

        if set(conditions) != set(known_conditions):
            raise util.CytoflowError(
                "Metadata for this tube should be {}".format(
                    list(known_conditions.keys())))

        # add the conditions to tube's internal data frame.  specify the conditions
        # dtype using self.conditions.  check for errors as we do so.

        # take this chance to up-convert the float32s to float64.
        # concatenation (below) did this automatically, but
        # only in certain cases.... :-/

        # TODO - the FCS standard says you can specify the precision.
        # check with int/float/double files!

        new_data = data.astype("float64", copy=True)

        for meta_name, meta_value in conditions.items():
            meta_type = known_conditions[meta_name].dtype

            if isinstance(meta_type, CategoricalDtype):
                meta_type = CategoricalDtype([meta_value])

            new_data[meta_name] = \
                pd.Series(data = [meta_value] * len(new_data),
                          index = new_data.index,
                          dtype = meta_type)

            # if we're categorical, merge the categories so both frames
            # share one categorical dtype before concatenation
            if isinstance(meta_type, CategoricalDtype) and meta_name in self.data:
                cats = set(self.data[meta_name].cat.categories) | set(
                    new_data[meta_name].cat.categories)
                cats = sorted(cats)
                self.data[meta_name] = self.data[meta_name].cat.set_categories(
                    cats)
                new_data[meta_name] = new_data[meta_name].cat.set_categories(
                    cats)

        # pd.concat replaces DataFrame.append, which pandas 2.0 removed.
        self.data = pd.concat([self.data, new_data],
                              ignore_index=True,
                              sort=True)
Пример #19
0
class MEditorAreaPane(HasTraits):
    """ Mixin implementing the toolkit-independent parts of IEditorAreaPane:
    creating, finding and registering editors and editor factories.
    """

    #### 'IEditorAreaPane' interface ##########################################

    active_editor = Instance(IEditor)
    editors = List(IEditor)
    file_drop_extensions = List(Str)
    file_dropped = Event(File)
    hide_tab_bar = Bool(False)

    #### Protected traits #####################################################

    # Maps each registered factory to the list of filter callables that
    # decide whether it can edit a given object.
    _factory_map = Dict(Callable, List(Callable))

    ###########################################################################
    # 'IEditorAreaPane' interface.
    ###########################################################################

    def create_editor(self, obj, factory=None):
        """ Creates an editor for an object.

        If no factory is given, one is looked up with ``get_factory``.
        Returns None when no suitable factory exists.
        """
        if factory is None:
            factory = self.get_factory(obj)

        if factory is not None:
            return factory(editor_area=self, obj=obj)

        return None

    def edit(self, obj, factory=None, use_existing=True):
        """ Edit an object.

        If ``use_existing`` is true and an editor for ``obj`` is already open,
        that editor is activated and returned. Otherwise a new editor is
        created, added and activated. Returns None (after logging a warning)
        when no editor could be created.
        """
        if use_existing:
            # Is the object already being edited in the window?
            editor = self.get_editor(obj)
            if editor is not None:
                self.activate_editor(editor)
                return editor

        # If not, create an editor for it.
        editor = self.create_editor(obj, factory)
        if editor is None:
            logger.warning('Cannot create editor for obj %r', obj)

        else:
            self.add_editor(editor)
            self.activate_editor(editor)

        return editor

    def get_editor(self, obj):
        """ Returns the editor for an object, or None if it is not being
        edited.
        """
        for editor in self.editors:
            if editor.obj == obj:
                return editor
        return None

    def get_factory(self, obj):
        """ Returns an editor factory suitable for editing an object.

        The first factory with a filter that returns a true value for ``obj``
        is returned. A filter that raises is treated as a non-match and the
        exception is logged at debug level. Returns None if nothing matches.
        """
        # items() works on both Python 2 and 3; iteritems() is Python 2 only.
        for factory, filters in self._factory_map.items():
            for filter_ in filters:
                try:
                    if filter_(obj):
                        return factory
                except Exception:
                    # A misbehaving filter must not stop other factories from
                    # being considered, but leave a trace for debugging.
                    logger.debug('Editor factory filter %r raised for obj %r',
                                 filter_, obj, exc_info=True)
        return None

    def register_factory(self, factory, filter):
        """ Registers a factory for creating editors.

        ``filter`` is a callable taking the object to edit; the factory is
        used when it returns a true value. A factory may be registered with
        several filters.
        """
        self._factory_map.setdefault(factory, []).append(filter)

    def unregister_factory(self, factory):
        """ Unregisters a factory (and all of its filters) for creating
        editors. Unknown factories are ignored.
        """
        if factory in self._factory_map:
            del self._factory_map[factory]
Пример #20
0
class BaseComponent(HasStrictTraits):
    """ The most base class of the Enaml component hierarchy.

    All declarative Enaml classes should inherit from this class. This 
    class is not meant to be instantiated directly.

    """
    #: A readonly property which returns the current instance of
    #: the component. This allows declarative Enaml components to
    #: access self according to the standard attribute scoping rules.
    self = Property

    #: The parent component of this component. It is stored as a weakref
    #: to mitigate issues with reference cycles. A top-level component's
    #: parent is None.
    parent = WeakRef('BaseComponent')

    #: The list of children for this component. This is a read-only
    #: lazy property that is computed based on the static list of
    #: _subcomponents and the items they return by calling their
    #: 'get_actual' method. This list should not be manipulated by
    #: user code.
    children = LazyProperty(
        List(Instance('BaseComponent')), 
        depends_on='_subcomponents:_actual_updated',
    )

    #: Whether the component has been initialized or not. This will be 
    #: set to True after all of the setup() steps defined here are 
    #: completed. It should not be changed afterwards. This can be used 
    #: to trigger certain actions that need to occur after the component 
    #: has been set up.
    initialized = Bool(False)

    #: An optional name to give to this component to assist in finding
    #: it in the tree. See the 'find_by_name' method.
    name = Str

    #: A reference to the toolkit that was used to create this object.
    toolkit = Instance(Toolkit)

    #: The private dictionary of expression objects that are bound to 
    #: attributes on this component. It should not be manipulated by
    #: user code. Rather, expressions should be bound by calling the 
    #: 'bind_expression' method.
    _expressions = Dict(Str, List(Instance(AbstractExpression)))

    #: The private list of virtual base classes that were used to 
    #: instantiate this component from Enaml source code. The 
    #: EnamlFactory class of the Enaml runtime will directly append
    #: to this list as necessary.
    _bases = List

    #: The private internal list of subcomponents for this component. 
    #: This list should not be manipulated by the user, and should not
    #: be changed after initialization. It can, however, be redefined
    #: by subclasses to limit the type or number of subcomponents.
    _subcomponents = List(Instance('BaseComponent'))

    #: A private event that should be emitted by a component when the 
    #: results of calling get_actual() will result in new values. 
    #: This event is listened to by the parent of subcomponents in order 
    #: to know when to rebuild its list of children. User code will not 
    #: typically interact with this event.
    _actual_updated = EnamlEvent

    #: The HasTraits class defines a class attribute 'set' which is
    #: a deprecated alias for the 'trait_set' method. The problem
    #: is that having that as an attribute interferes with the 
    #: ability of Enaml expressions to resolve the builtin 'set',
    #: since the dynamic attribute scoping takes precedence over
    #: builtins. This resets those ill-effects.
    set = Disallow

    #--------------------------------------------------------------------------
    # Special Methods
    #--------------------------------------------------------------------------
    def __repr__(self):
        """ Return the repr of the originating Enaml factory when there is
        one, else fall back to the default repr.

        With fewer than two entries in '_bases' there is no virtual base
        class beyond the root constructor, so the super class' repr is used.
        Otherwise the repr of the first entry in '_bases' is returned.

        """
        virtual_bases = self._bases
        if len(virtual_bases) < 2:
            return super(BaseComponent, self).__repr__()
        # The final entry is always the constructor; the first one is the
        # virtual base class the component was declared from.
        return repr(virtual_bases[0])

    #--------------------------------------------------------------------------
    # Property Getters
    #--------------------------------------------------------------------------
    def _get_self(self):
        """ Getter backing the read-only 'self' property; it simply hands
        back this component instance.

        """
        return self
        
    def _get_children(self):
        """ The lazy property getter for the 'children' attribute.

        Returns a new flat list made by concatenating, in order, the lists
        returned by calling 'get_actual()' on each subcomponent.

        """
        # Extend one list instead of sum(lists, []), which re-copies the
        # accumulated list for every subcomponent (quadratic).
        children = []
        for subcomponent in self._subcomponents:
            children.extend(subcomponent.get_actual())
        return children
    
    #--------------------------------------------------------------------------
    # Component Manipulation
    #--------------------------------------------------------------------------
    def get_actual(self):
        """ Return the components this object contributes to its parent's
        children.

        The base implementation contributes just this component, as a
        one-element list. Subclasses that need to contribute different
        components to their parent's children should reimplement it.

        """
        contributed = [self]
        return contributed
        
    def add_subcomponent(self, component):
        """ Attach ``component`` as a subcomponent of this object.

        The component's 'parent' is pointed at this object first, and the
        component is then appended to the internal list of subcomponents.
        Subclasses may override this to filter or otherwise handle certain
        subcomponents differently.

        """
        # Link the parent before appending so the order of effects matches
        # the established behavior.
        component.parent = self
        subcomponents = self._subcomponents
        subcomponents.append(component)
    
    #--------------------------------------------------------------------------
    # Setup Methods 
    #--------------------------------------------------------------------------
    def setup(self, parent=None):
        """ Run the setup process for the ui tree.

        Setup is split into ordered steps so that the component tree is in a
        consistent state when defaults computed from expressions are
        evaluated. The steps, in order, are:

        1)  Abstract objects create their internal toolkit object
        2)  Abstract objects initialize their internal toolkit object
        3)  Bound expression values are explicitly applied
        4)  Abstract objects bind their event handlers
        5)  Abstract objects are added as listeners to the shell object
        6)  Visibility is initialized
        7)  Layout is initialized
        8)  A finalization pass is made
        9)  Nodes are marked as initialized

        On BaseComponent most steps only forward to the subcomponents;
        subclasses that take part in a given step re-implement that step's
        method.

        Parameters
        ----------
        parent : native toolkit widget, optional
            When embedding this BaseComponent into a non-Enaml GUI, the
            toolkit widget that should act as this component's parent.

        """
        # Only the first step takes an argument.
        self._setup_create_widgets(parent)

        remaining_steps = (
            '_setup_init_widgets',
            '_setup_eval_expressions',
            '_setup_bind_widgets',
            '_setup_listeners',
            '_setup_init_visibility',
            '_setup_init_layout',
            '_setup_finalize',
            '_setup_set_initialized',
        )
        # Each method is looked up right before it is called.
        for step_name in remaining_steps:
            getattr(self, step_name)()

    def _setup_create_widgets(self, parent):
        """ Widget-creation setup step.

        The base implementation creates nothing itself; it forwards
        ``parent`` to the same step on every subcomponent. Subclasses that
        drive gui toolkit widgets should reimplement this method to create
        the underlying toolkit widget(s).

        """
        for subcomponent in self._subcomponents:
            subcomponent._setup_create_widgets(parent)

    def _setup_init_widgets(self):
        """ Widget-initialization setup step.

        The base implementation only runs the same step on every
        subcomponent. Subclasses that drive gui toolkit widgets should
        reimplement this method to initialize their internal toolkit
        widget(s).

        """
        for subcomponent in self._subcomponents:
            subcomponent._setup_init_widgets()

    def _setup_eval_expressions(self):
        """ Expression-evaluation setup step.

        Reads every attribute that has a bound expression so that each one
        is initialized, even if no earlier setup step touched it, then runs
        the same step on every subcomponent.

        """
        # The value is discarded; the getattr itself triggers evaluation.
        for attr_name in self._expressions:
            getattr(self, attr_name)

        for subcomponent in self._subcomponents:
            subcomponent._setup_eval_expressions()

    def _setup_bind_widgets(self):
        """ A setup method that, by default, is a no-op. Subclasses 
        that drive gui toolkit widgets should reimplement this method
        to bind any event handlers of their internal toolkit widget(s).

        """
        for child in self._subcomponents:
            child._setup_bind_widgets()

    def _setup_listeners(self):
        """ A setup method that, by default, is a no-op. Subclasses 
        that drive gui toolkit widgets should reimplement this method
        to setup an traits listeners necessary to drive their internal
        toolkit widget(s).

        """
        for child in self._subcomponents:
            child._setup_listeners()

    def _setup_init_visibility(self):
        """ A setup method that, by default, is a no-op. Subclasses 
        that drive gui toolkit widgets should reimplement this method
        to initialize the visibility of their widgets.

        """
        for child in self._subcomponents:
            child._setup_init_visibility()

    def _setup_init_layout(self):
        """ A setup method that, by default, is a no-op. Subclasses 
        that manage layout should reimplement this method to initialize
        their underlying layout.

        """
        for child in self._subcomponents:
            child._setup_init_layout()

    def _setup_finalize(self):
        """ A setup method that, by default, is a no-op. Subclasses
        that need to perform process after layout is initialized but
        before a node is marked as fully initialized should reimplement
        this method.

        """
        for child in self._subcomponents:
            child._setup_finalize()

    def _setup_set_initialized(self):
        """ A setup method which updates the initialized attribute of 
        the component to True. This is performed bottom-up.

        """
        for child in self._subcomponents:
            child._setup_set_initialized()
        self.initialized = True

    #--------------------------------------------------------------------------
    # Teardown Methods
    #--------------------------------------------------------------------------
    def destroy(self):
        """ Tear down this component and its subtree.

        Calls 'destroy' on each subcomponent in order, then empties the
        subcomponent list and the bound-expression mapping so this
        component no longer references either. Subclasses that need more
        control over destruction should reimplement this method.

        """
        for subcomponent in self._subcomponents:
            subcomponent.destroy()
        # Empty the existing list object in place rather than rebinding
        # the attribute to a new list.
        del self._subcomponents[:]
        self._expressions.clear()

    #--------------------------------------------------------------------------
    # Layout Stubs
    #--------------------------------------------------------------------------
    def relayout(self):
        """ Refresh the layout of this component's children immediately.

        The base implementation delegates to the parent, so the request
        travels up the tree until some ancestor implements it. At the
        root (no parent), it does nothing. Implementors must finish all
        necessary work before returning.

        """
        parent = self.parent
        if parent is None:
            return
        parent.relayout()

    def request_relayout(self):
        """ Ask for the children's layout to be refreshed later.

        The base implementation delegates to the parent, so the request
        travels up the tree until some ancestor implements it. At the
        root (no parent), it does nothing. Implementors should return
        immediately and perform the relayout at some future point.

        """
        parent = self.parent
        if parent is None:
            return
        parent.request_relayout()

    def refresh(self):
        """ Reposition this component's children immediately.

        The base implementation delegates to the parent, so the request
        travels up the tree until some ancestor implements it. At the
        root (no parent), it does nothing. Implementors must complete
        the refresh before returning.

        Note: This method should perform less work than 'relayout' and
            should typically only need to be called when the children
            need to be repositioned, rather than have all of their layout
            relationships recomputed.

        """
        parent = self.parent
        if parent is None:
            return
        parent.refresh()
        
    def request_refresh(self):
        """ Ask for this component's children to be repositioned later.

        The base implementation delegates to the parent, so the request
        travels up the tree until some ancestor implements it. At the
        root (no parent), it does nothing. Implementors should return
        immediately and complete the refresh at some future time.

        Note: This method should perform less work than 'relayout' and
            should typically only need to be called when the children
            need to be repositioned, rather than have all of their layout
            relationships recomputed.

        """
        parent = self.parent
        if parent is None:
            return
        parent.request_refresh()
        
    def request_relayout_task(self, callback, *args, **kwargs):
        """ Schedule 'callback(*args, **kwargs)', then a relayout.

        The base implementation forwards the callback and its arguments
        unchanged to the parent, so the request travels up the tree until
        some ancestor implements it. At the root (no parent), it does
        nothing. Implementors should run the callback at some future
        point and follow it with a relayout. They are encouraged to
        coalesce several such requests into a single relayout.

        """
        parent = self.parent
        if parent is None:
            return
        parent.request_relayout_task(callback, *args, **kwargs)
        
    def request_refresh_task(self, callback, *args, **kwargs):
        """ Schedule a callback to be executed, followed by a refresh.

        By default, this method proxies the call, with the callback and
        its arguments unchanged, up the hierarchy until an implementor is
        found. At the root (no parent), it does nothing. Implementors
        should ensure that the callback is executed with the given
        arguments at some point in the future and is followed by a
        refresh. It is suggested that implementors collapse multiple
        calls to this method into a single refresh.

        """
        parent = self.parent
        if parent is not None:
            parent.request_refresh_task(callback, *args, **kwargs)

    #--------------------------------------------------------------------------
    # Bound Attribute Handling
    #--------------------------------------------------------------------------
    def add_attribute(self, name, attr_type=object, is_event=False):
        """ Add a user-declared attribute (or event) to this component.

        Values assigned to the new attribute are validated against
        'attr_type'. An existing trait with the same name may only be
        replaced if its trait type is Disallow or is itself a
        UserAttribute / UserEvent. Any other existing trait causes a
        TypeError.

        Parameters
        ----------
        name : string
            The name of the attribute to add.

        attr_type : type-like object, optional
            An object that behaves like a type for the purposes of a
            call to isinstance. Defaults to object.

        is_event : bool, optional
            If True, the added attribute will behave like an event.
            Otherwise, it will behave like a normal attribute. The
            default is False.

        Raises
        ------
        TypeError
            If a non-overridable trait named 'name' already exists, or
            if building/adding the new trait from 'attr_type' raises a
            TypeError.

        """
        # Query the trait directly instead of using hasattr, which could
        # prematurely trigger a trait initialization. Disallow is what
        # HasStrictTraits reports for undeclared names; UserAttribute and
        # UserEvent may be re-declared so that a derived component can
        # specialize the types declared by the component it derives from.
        existing = self.trait(name)
        if existing is not None:
            existing_type = existing.trait_type
            overridable = (existing_type is Disallow or
                           isinstance(existing_type, (UserAttribute, UserEvent)))
            if not overridable:
                msg = ("Cannot add '%s' attribute. The '%s' attribute on "
                       "the %s object already exists.")
                raise TypeError(msg % (name, name, self))

        # Methods or other non-trait attributes sharing the name are not
        # checked on purpose, so that Enaml code is free to override them.
        trait_factory = UserEvent if is_event else UserAttribute
        try:
            self.add_trait(name, trait_factory(attr_type))
        except TypeError:
            msg = ("'%s' is not a valid type for the '%s' attribute "
                   "declaration on %s")
            raise TypeError(msg % (attr_type, name, self))

    def bind_expression(self, name, expression, notify_only=False):
        """ Binds the given expression to the attribute 'name'.

        If the attribute does not exist, an exception is raised. A
        strong reference to the expression object is kept internally.
        If the expression is not notify_only and the object is already
        fully initialized, the value of the expression will be applied
        immediately (unless its eval() returns NotImplemented).

        Parameters
        ----------
        name : string
            The name of the attribute on which to bind the expression.

        expression : AbstractExpression
            A concrete implementation of AbstractExpression.

        notify_only : bool, optional
            If True, the expression is only a notifier, in which case
            multiple binding is allowed, otherwise the new expression
            overrides any old non-notify expression. Defaults to False.

        Raises
        ------
        AttributeError
            If the object has no trait named 'name', or that trait's
            type is Disallow.

        """
        curr = self.trait(name)
        if curr is None or curr.trait_type is Disallow:
            msg = "Cannot bind expression. %s object has no attribute '%s'"
            raise AttributeError(msg % (self, name))

        # If this is the first time an expression is being bound to the
        # given attribute, then we hook up a change handler. This ensures
        # that we only get one notification event per bound attribute.
        # We also create the notification entry in the dict, which is 
        # a list with at least one item. The first item will always be
        # the left associative expression (or None) and all following
        # items will be the notify_only expressions.
        expressions = self._expressions
        if name not in expressions:
            self.on_trait_change(self._on_bound_attr_changed, name)
            expressions[name] = [None]

        # There can be multiple notify_only expressions bound to a 
        # single attribute, so they just get appended to the end of
        # the list. Otherwise, the left associative expression gets
        # placed at the zero position of the list, overriding any
        # existing expression.
        if notify_only:
            expressions[name].append(expression)
        else:
            # Disconnect the previous left associative expression (if any)
            # before connecting the new one, so that only the newest
            # expression's expression_changed signal updates the attribute.
            handler = self._on_expression_changed
            old = expressions[name][0]
            if old is not None:
                old.expression_changed.disconnect(handler)
            expression.expression_changed.connect(handler)
            expressions[name][0] = expression
        
            # Hookup support for default value computation.
            if not self.initialized:
                # We only need to add an ExpressionTrait once, since it 
                # will reach back into the _expressions dict as needed
                # and retrieve the bound expression.
                if not isinstance(curr.trait_type, ExpressionTrait):
                    self.add_trait(name, ExpressionTrait(curr))
            else:
                # If the component is already initialized, and the given
                # expression supports evaluation, update the attribute 
                # with the current value. An eval() result of
                # NotImplemented means "no value available"; the
                # attribute is then left unchanged.
                val = expression.eval()
                if val is not NotImplemented:
                    setattr(self, name, val)

    def _on_expression_changed(self, expression, name, value):
        """ A private signal callback for the expression_changed signal
        of the bound expressions. It updates the value of the attribute
        with the new value from the expression.

        """
        setattr(self, name, value)
    
    def _on_bound_attr_changed(self, obj, name, old, new):
        """ A private handler which is called when any attribute which
        has a bound signal changes. It calls the notify method on each
        of the expressions bound to that attribute, but only after the
        component has been fully initialized.

        """
        # The check for None is for the case where there are no left 
        # associative expressions bound to the attribute, so the first
        # entry in the list is still None.
        if self.initialized:
            for expr in self._expressions[name]:
                if expr is not None:
                    expr.notify(old, new)

    #--------------------------------------------------------------------------
    # Auxiliary Methods 
    #--------------------------------------------------------------------------
    def when(self, switch):
        """ Return this component if 'switch' is truthy, else None.

        Handy for conditionally switching off a component's effect, for
        example when building constraints-based layouts.

        Parameters
        ----------
        switch : bool
            Decides whether the instance or None is returned.

        Returns
        -------
        result : self or None
            self when 'switch' evaluates as true, otherwise None.

        """
        return self if switch else None
    
    def traverse(self, depth_first=False):
        """ Yields all of the nodes in the tree, from this node downward.

        Parameters
        ----------
        depth_first : bool, optional
            If True, yield the nodes in depth first (pre-order) order.
            If False, yield the nodes in breadth first order. In both
            modes, siblings are visited in the order they appear in
            'children'. Defaults to False.

        """
        if depth_first:
            stack = [self]
            while stack:
                item = stack.pop()
                yield item
                # The stack is LIFO, so push the children in reverse:
                # the first child is then popped (and yielded) first,
                # instead of the last one.
                children = list(item.children)
                children.reverse()
                stack.extend(children)
        else:
            queue = deque([self])
            while queue:
                item = queue.popleft()
                yield item
                queue.extend(item.children)
    
    def traverse_ancestors(self, root=None):
        """ Yield this node's ancestors, nearest first, walking upward.

        Traversal starts at the parent of this node. It stops before
        yielding 'root', or when a node with no parent is passed.

        Parameters
        ----------
        root : BaseComponent, optional
            The component at which to stop the traversal; it is not
            itself yielded. Defaults to None, which walks to the top of
            the tree.

        """
        node = self.parent
        while node is not None and node is not root:
            yield node
            node = node.parent

    def find_by_name(self, name):
        """ Return the first component named 'name' in this subtree.

        Components are visited in the default breadth-first order of
        'traverse', starting with this node. The first one whose 'name'
        equals the argument is returned.

        Parameters
        ----------
        name : string
            The component name to look for.

        Returns
        -------
        result : BaseComponent or None
            The first matching component, or None if there is no match.

        """
        matches = (node for node in self.traverse() if node.name == name)
        return next(matches, None)

    def toplevel_component(self):
        """ Return the root of the tree containing this component.

        Follows 'parent' links upward from this node. It returns the
        first node whose parent is None, which is this node itself if it
        has no parent.

        """
        node = self
        parent = node.parent
        while parent is not None:
            node, parent = parent, parent.parent
        return node
Пример #21
0
class BaselineView(HasTraits):
  """ Traits view that plots and tabulates NED baseline solutions.

  The constructor registers callbacks on 'link' for baseline NED, baseline
  ECEF and IAR state messages. NED solutions are shown in a table, the
  most recent 1000 N/E/D samples are plotted, and each solution is
  appended to a timestamped CSV log file.
  """
  python_console_cmds = Dict()

  ns = List()
  es = List()
  ds = List()

  table = List()

  plot = Instance(Plot)
  plot_data = Instance(ArrayPlotData)

  running = Bool(True)
  position_centered = Bool(False)

  clear_button = SVGButton(
    label='', tooltip='Clear',
    filename=os.path.join(os.path.dirname(__file__), 'images', 'iconic', 'x.svg'),
    width=16, height=16
  )
  zoomall_button = SVGButton(
    label='', tooltip='Zoom All',
    filename=os.path.join(os.path.dirname(__file__), 'images', 'iconic', 'fullscreen.svg'),
    width=16, height=16
  )
  center_button = SVGButton(
    label='', tooltip='Center on Baseline', toggle=True,
    filename=os.path.join(os.path.dirname(__file__), 'images', 'iconic', 'target.svg'),
    width=16, height=16
  )
  paused_button = SVGButton(
    label='', tooltip='Pause', toggle_tooltip='Run', toggle=True,
    filename=os.path.join(os.path.dirname(__file__), 'images', 'iconic', 'pause.svg'),
    toggle_filename=os.path.join(os.path.dirname(__file__), 'images', 'iconic', 'play.svg'),
    width=16, height=16
  )

  reset_button = Button(label='Reset Filters')
  reset_iar_button = Button(label='Reset IAR')
  init_base_button = Button(label='Init. with known baseline')

  traits_view = View(
    HSplit(
      Item('table', style = 'readonly', editor = TabularEditor(adapter=SimpleAdapter()), show_label=False, width=0.3),
      VGroup(
        HGroup(
          Item('paused_button', show_label=False),
          Item('clear_button', show_label=False),
          Item('zoomall_button', show_label=False),
          Item('center_button', show_label=False),
          Item('reset_button', show_label=False),
          Item('reset_iar_button', show_label=False),
          Item('init_base_button', show_label=False),
        ),
        Item(
          'plot',
          show_label = False,
          editor = ComponentEditor(bgcolor = (0.8,0.8,0.8)),
        )
      )
    )
  )

  def _zoomall_button_fired(self):
    # Let both plot axes auto-scale to the data again.
    self.plot.index_range.low_setting = 'auto'
    self.plot.index_range.high_setting = 'auto'
    self.plot.value_range.low_setting = 'auto'
    self.plot.value_range.high_setting = 'auto'

  def _center_button_fired(self):
    self.position_centered = not self.position_centered

  def _paused_button_fired(self):
    self.running = not self.running

  def _reset_button_fired(self):
    self.link.send_message(sbp_messages.RESET_FILTERS, '\x00')

  def _reset_iar_button_fired(self):
    self.link.send_message(sbp_messages.RESET_FILTERS, '\x01')

  def _init_base_button_fired(self):
    self.link.send_message(sbp_messages.INIT_BASE, '')

  def _clear_button_fired(self):
    # Drop the sample history and empty the plotted series.
    self.ns = []
    self.es = []
    self.ds = []
    self.plot_data.set_data('n', [])
    self.plot_data.set_data('e', [])
    self.plot_data.set_data('d', [])
    self.plot_data.set_data('t', [])

  def _baseline_callback_ecef(self, data):
    #Don't do anything for ECEF currently
    return

  def iar_state_callback(self, data):
    # struct.unpack always returns a tuple, even for a single field; keep
    # the lone element so the table shows a number rather than '(n,)'.
    self.num_hyps = struct.unpack('<I', data)[0]

  def _baseline_callback_ned(self, data):
    # Updating an ArrayPlotData isn't thread safe (see chaco issue #9), so
    # actually perform the update in the UI thread.
    if self.running:
      GUI.invoke_later(self.baseline_callback, data)

  def update_table(self):
    # 'table' is a list of (label, value) pairs (see baseline_callback);
    # a list has no .items(), so copy the pairs directly.
    self._table_list = list(self.table)

  def baseline_callback(self, data):
    """ Decode a baseline NED message and refresh table, log and plot. """
    soln = sbp_messages.BaselineNED(data)

    # Scale the n/e/d fields by 1e-3 (presumably mm -> m; confirm
    # against the SBP message definition).
    soln.n = soln.n * 1e-3
    soln.e = soln.e * 1e-3
    soln.d = soln.d * 1e-3

    dist = np.sqrt(soln.n**2 + soln.e**2 + soln.d**2)

    table = []

    table.append(('N', soln.n))
    table.append(('E', soln.e))
    table.append(('D', soln.d))
    table.append(('Dist.', dist))
    table.append(('Num. Sats.', soln.n_sats))
    table.append(('Flags', hex(soln.flags)))
    if soln.flags & 1:
      table.append(('Mode', 'Fixed RTK'))
    else:
      table.append(('Mode', 'Float'))
    table.append(('IAR Num. Hyps.', self.num_hyps))

    # The log file is created lazily on the first solution.
    if self.log_file is None:
      self.log_file = open(time.strftime("baseline_log_%Y%m%d-%H%M%S.csv"), 'w')

    self.log_file.write('%.2f,%.4f,%.4f,%.4f,%d\n' % (soln.tow, soln.n, soln.e, soln.d, soln.n_sats))
    self.log_file.flush()

    self.ns.append(soln.n)
    self.es.append(soln.e)
    self.ds.append(soln.d)

    # Keep only the most recent 1000 samples for plotting.
    self.ns = self.ns[-1000:]
    self.es = self.es[-1000:]
    self.ds = self.ds[-1000:]

    self.plot_data.set_data('n', self.ns)
    self.plot_data.set_data('e', self.es)
    self.plot_data.set_data('d', self.ds)
    self.plot_data.set_data('ref_n', [0.0, soln.n])
    self.plot_data.set_data('ref_e', [0.0, soln.e])
    self.plot_data.set_data('ref_d', [0.0, soln.d])
    t = range(len(self.ns))
    self.plot_data.set_data('t', t)

    if self.position_centered:
      # Re-centre the current view on the latest solution, keeping the
      # current span of each axis.
      d = (self.plot.index_range.high - self.plot.index_range.low) / 2.
      self.plot.index_range.set_bounds(soln.n - d, soln.n + d)
      d = (self.plot.value_range.high - self.plot.value_range.low) / 2.
      self.plot.value_range.set_bounds(soln.e - d, soln.e + d)

    self.table = table

  def __init__(self, link):
    super(BaselineView, self).__init__()

    # NOTE(review): this file is opened lazily in baseline_callback and
    # is never closed here; confirm whether an owner closes it.
    self.log_file = None

    self.num_hyps = 0

    self.plot_data = ArrayPlotData(n=[0.0], e=[0.0], d=[0.0], t=[0.0], ref_n=[0.0], ref_e=[0.0], ref_d=[0.0])
    self.plot = Plot(self.plot_data)

    self.plot.plot(('e', 'n'), type='line', name='line', color=(0, 0, 0, 0.1))
    self.plot.plot(('e', 'n'), type='scatter', name='points', color='blue', marker='dot', line_width=0.0, marker_size=1.0)
    self.plot.plot(('ref_e', 'ref_n'),
        type='scatter',
        color='red',
        marker='plus',
        marker_size=5,
        line_width=1.5
    )

    self.plot.index_axis.tick_label_position = 'inside'
    self.plot.index_axis.tick_label_color = 'gray'
    self.plot.index_axis.tick_color = 'gray'
    self.plot.value_axis.tick_label_position = 'inside'
    self.plot.value_axis.tick_label_color = 'gray'
    self.plot.value_axis.tick_color = 'gray'
    self.plot.padding = (0, 1, 0, 1)

    self.plot.tools.append(PanTool(self.plot))
    zt = ZoomTool(self.plot, zoom_factor=1.1, tool_mode="box", always_on=False)
    self.plot.overlays.append(zt)

    self.link = link
    self.link.add_callback(sbp_messages.SBP_BASELINE_NED, self._baseline_callback_ned)
    self.link.add_callback(sbp_messages.SBP_BASELINE_ECEF, self._baseline_callback_ecef)
    self.link.add_callback(sbp_messages.IAR_STATE, self.iar_state_callback)

    self.python_console_cmds = {
      'baseline': self
    }
Пример #22
0
class SystemMonitorView(HasTraits):
    """ Traits view showing device thread, UART and observation-link stats.

    The constructor registers callbacks on 'link' for heartbeat, thread
    state and UART state messages. The view shows a per-thread table
    (sorted by CPU), per-UART counters and throughput, and observation
    latency/period statistics.
    """
    python_console_cmds = Dict()

    _threads_table_list = List()
    threads = List()
    uart_a_crc_error_count = Int(0)
    uart_a_io_error_count = Int(0)
    uart_a_rx_buffer = Float(0)
    uart_a_tx_buffer = Float(0)
    uart_a_tx_KBps = Float(0)
    uart_a_rx_KBps = Float(0)

    uart_b_crc_error_count = Int(0)
    uart_b_io_error_count = Int(0)
    uart_b_rx_buffer = Float(0)
    uart_b_tx_buffer = Float(0)
    uart_b_tx_KBps = Float(0)
    uart_b_rx_KBps = Float(0)

    ftdi_crc_error_count = Int(0)
    ftdi_io_error_count = Int(0)
    ftdi_rx_buffer = Float(0)
    ftdi_tx_buffer = Float(0)
    ftdi_tx_KBps = Float(0)
    ftdi_rx_KBps = Float(0)

    msg_obs_avg_latency_ms = Int(0)
    msg_obs_min_latency_ms = Int(0)
    msg_obs_max_latency_ms = Int(0)
    msg_obs_window_latency_ms = Int(0)

    msg_obs_avg_period_ms = Int(0)
    msg_obs_min_period_ms = Int(0)
    msg_obs_max_period_ms = Int(0)
    msg_obs_window_period_ms = Int(0)

    # NOTE(review): 'aligment' looks like a misspelling of 'alignment';
    # confirm which keyword SVGButton actually honours before changing it.
    piksi_reset_button = SVGButton(label='Reset Piksi',
                                   tooltip='Reset Piksi',
                                   filename=os.path.join(
                                       determine_path(), 'images',
                                       'fontawesome', 'power27.svg'),
                                   width=16,
                                   height=16,
                                   aligment='center')

    traits_view = View(
        VGroup(
            Item(
                '_threads_table_list',
                style='readonly',
                editor=TabularEditor(adapter=SimpleAdapter()),
                show_label=False,
                width=0.85,
            ),
            HGroup(
                VGroup(
                    HGroup(VGroup(Item('msg_obs_window_latency_ms',
                                       label='Curr',
                                       style='readonly',
                                       format_str='%dms'),
                                  Item('msg_obs_avg_latency_ms',
                                       label='Avg',
                                       style='readonly',
                                       format_str='%dms'),
                                  Item('msg_obs_min_latency_ms',
                                       label='Min',
                                       style='readonly',
                                       format_str='%dms'),
                                  Item('msg_obs_max_latency_ms',
                                       label='Max',
                                       style='readonly',
                                       format_str='%dms'),
                                  label='Latency',
                                  show_border=True),
                           VGroup(
                               Item('msg_obs_window_period_ms',
                                    label='Curr',
                                    style='readonly',
                                    format_str='%dms'),
                               Item('msg_obs_avg_period_ms',
                                    label='Avg',
                                    style='readonly',
                                    format_str='%dms'),
                               Item('msg_obs_min_period_ms',
                                    label='Min',
                                    style='readonly',
                                    format_str='%dms'),
                               Item('msg_obs_max_period_ms',
                                    label='Max',
                                    style='readonly',
                                    format_str='%dms'),
                               label='Period',
                               show_border=True,
                           ),
                           show_border=True,
                           label="Observation Connection Monitor"),
                    HGroup(
                        Spring(width=50, springy=False),
                        Item('piksi_reset_button',
                             show_label=False,
                             width=0.50),
                    ),
                ),
                VGroup(
                    Item('uart_a_crc_error_count',
                         label='CRC Errors',
                         style='readonly'),
                    Item('uart_a_io_error_count',
                         label='IO Errors',
                         style='readonly'),
                    Item('uart_a_tx_buffer',
                         label='TX Buffer %',
                         style='readonly',
                         format_str='%.1f'),
                    Item('uart_a_rx_buffer',
                         label='RX Buffer %',
                         style='readonly',
                         format_str='%.1f'),
                    Item('uart_a_tx_KBps',
                         label='TX KBytes/s',
                         style='readonly',
                         format_str='%.2f'),
                    Item('uart_a_rx_KBps',
                         label='RX KBytes/s',
                         style='readonly',
                         format_str='%.2f'),
                    label='UART A',
                    show_border=True,
                ),
                VGroup(
                    Item('uart_b_crc_error_count',
                         label='CRC Errors',
                         style='readonly'),
                    Item('uart_b_io_error_count',
                         label='IO Errors',
                         style='readonly'),
                    Item('uart_b_tx_buffer',
                         label='TX Buffer %',
                         style='readonly',
                         format_str='%.1f'),
                    Item('uart_b_rx_buffer',
                         label='RX Buffer %',
                         style='readonly',
                         format_str='%.1f'),
                    Item('uart_b_tx_KBps',
                         label='TX KBytes/s',
                         style='readonly',
                         format_str='%.2f'),
                    Item('uart_b_rx_KBps',
                         label='RX KBytes/s',
                         style='readonly',
                         format_str='%.2f'),
                    label='UART B',
                    show_border=True,
                ),
                VGroup(
                    Item('ftdi_crc_error_count',
                         label='CRC Errors',
                         style='readonly'),
                    Item('ftdi_io_error_count',
                         label='IO Errors',
                         style='readonly'),
                    Item('ftdi_tx_buffer',
                         label='TX Buffer %',
                         style='readonly',
                         format_str='%.1f'),
                    Item('ftdi_rx_buffer',
                         label='RX Buffer %',
                         style='readonly',
                         format_str='%.1f'),
                    Item('ftdi_tx_KBps',
                         label='TX KBytes/s',
                         style='readonly',
                         format_str='%.2f'),
                    Item('ftdi_rx_KBps',
                         label='RX KBytes/s',
                         style='readonly',
                         format_str='%.2f'),
                    label='USB UART',
                    show_border=True,
                ),
            ),
        ), )

    def update_threads(self):
        """ Rebuild the thread table from 'threads', highest CPU first. """
        self._threads_table_list = [
            (thread_name, state.cpu, state.stack_free)
            for thread_name, state in sorted(
                self.threads, key=lambda x: x[1].cpu, reverse=True)
        ]

    def heartbeat_callback(self, sbp_msg, **metadata):
        """ On each heartbeat, publish the thread states collected since
        the previous heartbeat, then start a fresh collection.
        """
        if self.threads:
            self.update_threads()
            self.threads = []

    def thread_state_callback(self, sbp_msg, **metadata):
        """ Accumulate one thread-state message until the next heartbeat.

        Note: the incoming message is modified in place (its name is
        defaulted and its cpu value is divided by 10).
        """
        if sbp_msg.name == '':
            sbp_msg.name = '(no name)'
        sbp_msg.cpu /= 10.
        self.threads.append((sbp_msg.name, sbp_msg))

    def _piksi_reset_button_fired(self):
        self.link(MsgReset(flags=0))

    def uart_state_callback(self, m, **metadata):
        """ Copy UART and observation-link statistics into the view.

        Buffer levels are rescaled from 0-255 to a 0-100 percentage.
        Observation period stats are only read for SBP_MSG_UART_STATE
        messages, not for the other message type this callback is
        registered for.
        """
        self.uart_a_tx_KBps = m.uart_a.tx_throughput
        self.uart_a_rx_KBps = m.uart_a.rx_throughput
        self.uart_a_crc_error_count = m.uart_a.crc_error_count
        self.uart_a_io_error_count = m.uart_a.io_error_count
        self.uart_a_tx_buffer = 100 * m.uart_a.tx_buffer_level / 255.0
        self.uart_a_rx_buffer = 100 * m.uart_a.rx_buffer_level / 255.0

        self.uart_b_tx_KBps = m.uart_b.tx_throughput
        self.uart_b_rx_KBps = m.uart_b.rx_throughput
        self.uart_b_crc_error_count = m.uart_b.crc_error_count
        self.uart_b_io_error_count = m.uart_b.io_error_count
        self.uart_b_tx_buffer = 100 * m.uart_b.tx_buffer_level / 255.0
        self.uart_b_rx_buffer = 100 * m.uart_b.rx_buffer_level / 255.0

        # These must target the declared 'ftdi_*' traits shown in the
        # "USB UART" panel; assigning 'uart_ftdi_*' silently created
        # unrelated attributes and left the panel stuck at zero.
        self.ftdi_tx_KBps = m.uart_ftdi.tx_throughput
        self.ftdi_rx_KBps = m.uart_ftdi.rx_throughput
        self.ftdi_crc_error_count = m.uart_ftdi.crc_error_count
        self.ftdi_io_error_count = m.uart_ftdi.io_error_count
        self.ftdi_tx_buffer = 100 * m.uart_ftdi.tx_buffer_level / 255.0
        self.ftdi_rx_buffer = 100 * m.uart_ftdi.rx_buffer_level / 255.0

        self.msg_obs_avg_latency_ms = m.latency.avg
        self.msg_obs_min_latency_ms = m.latency.lmin
        self.msg_obs_max_latency_ms = m.latency.lmax
        self.msg_obs_window_latency_ms = m.latency.current
        if m.msg_type == SBP_MSG_UART_STATE:
            self.msg_obs_avg_period_ms = m.obs_period.avg
            self.msg_obs_min_period_ms = m.obs_period.pmin
            self.msg_obs_max_period_ms = m.obs_period.pmax
            self.msg_obs_window_period_ms = m.obs_period.current

    def __init__(self, link):
        super(SystemMonitorView, self).__init__()
        self.link = link
        self.link.add_callback(self.heartbeat_callback, SBP_MSG_HEARTBEAT)
        self.link.add_callback(self.thread_state_callback,
                               SBP_MSG_THREAD_STATE)
        self.link.add_callback(self.uart_state_callback,
                               [SBP_MSG_UART_STATE, SBP_MSG_UART_STATE_DEPA])

        self.python_console_cmds = {'mon': self}
Пример #23
0
class InstanceFactoryChoice(InstanceChoiceItem):
    """ A choice item that produces new objects on demand.

    Each request for the item's object calls **klass** with the stored
    positional (**args**) and keyword (**kw_args**) arguments.
    """

    # -------------------------------------------------------------------------
    #  Trait definitions:
    # -------------------------------------------------------------------------

    #: Indicates whether an instance compatible with this item can be dragged
    #: and dropped rather than created
    droppable = Bool(False)

    #: Indicates whether the item can be selected by the user
    selectable = Bool(True)

    #: A class (or other callable) that can be used to create an item
    #: compatible with this item
    klass = Callable()

    #: Tuple of arguments to pass to **klass** to create an instance
    args = Tuple()

    #: Dictionary of arguments to pass to **klass** to create an instance
    kw_args = Dict(Str, Any)

    #: Does this item create new instances? This value overrides the default.
    is_factory = True

    def _resolve_class(self):
        """ Return the class of the objects this item produces.

        When **klass** is itself a class it is returned directly; otherwise
        (a factory function or other callable) an object is built through
        ``get_object`` and its class is returned.
        """
        factory = self.klass
        if issubclass(type(factory), type):
            return factory
        return self.get_object().__class__

    def get_name(self, object=None):
        """ Returns the name of the item.

        An explicitly set ``name`` wins; next comes a string ``name``
        attribute on *object*; the fallback is a user-friendly version of
        the produced class's ``__name__``.
        """
        if self.name != "":
            return self.name

        object_name = getattr(object, "name", None)
        if isinstance(object_name, str):
            return object_name

        return user_name_for(self._resolve_class().__name__)

    def get_object(self):
        """ Returns a new object built by calling **klass**.
        """
        return self.klass(*self.args, **self.kw_args)

    def is_droppable(self):
        """ Indicates whether the item supports drag and drop.
        """
        return self.droppable

    def is_compatible(self, object):
        """ Indicates whether *object* is an instance of the produced class.
        """
        return isinstance(object, self._resolve_class())

    def is_selectable(self):
        """ Indicates whether the item can be selected by the user.
        """
        return self.selectable
Пример #24
0
class DatasetManager(HasTraits):
    """Manage the point/cell attribute arrays of a TVTK dataset.

    The arrays of ``dataset`` are mirrored as numpy arrays in the
    ``point_*`` / ``cell_*`` dicts (keyed by array name), and a
    ``tvtk.AssignAttribute`` filter is used to change the active arrays so
    the pipeline is updated properly.
    """

    # The TVTK dataset we manage.
    dataset = Instance(tvtk.DataSet)

    # Our output, this is the dataset modified by us with different
    # active arrays.
    output = Property(Instance(tvtk.DataSet))

    # The point scalars for the dataset.  You may manipulate the arrays
    # in-place.  However adding new keys in this dict will not set the
    # data in the `dataset` for that you must explicitly call
    # `add_array`.
    point_scalars = Dict(Str, Array)
    # Point vectors.
    point_vectors = Dict(Str, Array)
    # Point tensors.
    point_tensors = Dict(Str, Array)

    # The cell scalars for the dataset.
    cell_scalars = Dict(Str, Array)
    cell_vectors = Dict(Str, Array)
    cell_tensors = Dict(Str, Array)

    # This filter allows us to change the attributes of the data
    # object and will ensure that the pipeline is properly taken care
    # of.  Directly setting the array in the VTK object will not do
    # this.
    _assign_attribute = Instance(tvtk.AssignAttribute,
                                 args=(),
                                 allow_none=False)

    ######################################################################
    # Public interface.
    ######################################################################
    def add_array(self, array, name, category='point'):
        """
        Add an array to the dataset to specified category ('point' or
        'cell').

        A 1-D array is stored as scalars.  A 2-D array must have 1, 3, 4 or
        9 columns; 3 columns are stored as vectors, 9 as tensors and 1 or 4
        as scalars.  The array is converted to a VTK array named `name`,
        added to the dataset's ``<category>_data`` and recorded in the
        matching ``<category>_<kind>`` dict.
        """
        shape = array.shape
        assert len(shape) <= 2, "Only 2D arrays can be added."
        data = getattr(self.dataset, '%s_data' % category)
        if len(shape) == 2:
            assert shape[1] in [1, 3, 4, 9], \
                    "Only Nxm arrays where (m in [1,3,4,9]) are supported"
            kind = {1: 'scalars', 3: 'vectors', 4: 'scalars',
                    9: 'tensors'}[shape[1]]
        else:
            kind = 'scalars'

        vtk_array = tvtk.to_tvtk(array2vtk(array))
        vtk_array.name = name
        data.add_array(vtk_array)
        getattr(self, '%s_%s' % (category, kind))[name] = array

    def remove_array(self, name, category='point'):
        """Remove an array by its name and optional category (point and
        cell).  Returns the removed array.

        Raises KeyError if no array with that name exists in `category`.
        """
        kind = self._find_array(name, category)
        data = getattr(self.dataset, '%s_data' % category)
        data.remove_array(name)
        return getattr(self, '%s_%s' % (category, kind)).pop(name)

    def rename_array(self, name1, name2, category='point'):
        """Rename a particular array from `name1` to `name2`.

        Raises KeyError if `name1` does not exist in `category`.
        """
        kind = self._find_array(name1, category)
        data = getattr(self.dataset, '%s_data' % category)
        data.get_array(name1).name = name2
        arrays = getattr(self, '%s_%s' % (category, kind))
        arrays[name2] = arrays.pop(name1)

    def activate(self, name, category='point'):
        """Make the specified array the active one.
        """
        kind = self._find_array(name, category)
        self._activate_data_array(kind, category, name)

    def update(self):
        """Update the dataset when the arrays are changed.
        """
        self.dataset.modified()
        self._assign_attribute.update()

    ######################################################################
    # Non-public interface.
    ######################################################################
    def _dataset_changed(self, value):
        self._setup_data()
        self._assign_attribute.input = value

    def _get_output(self):
        return self._assign_attribute.output

    def _setup_data(self):
        """Updates the arrays from what is available in the input data.
        """
        pnt_attr, cell_attr = get_all_attributes(self.dataset)

        self._setup_data_arrays(cell_attr, 'cell')
        self._setup_data_arrays(pnt_attr, 'point')

    def _setup_data_arrays(self, attributes, d_type):
        """Given the dict of the attributes from the
        `get_all_attributes` function and the data type (point/cell)
        data this will setup the object and the data.

        For each of 'scalars', 'vectors' and 'tensors' the named VTK
        arrays are converted to numpy arrays, and the
        ``<d_type>_<attr>`` trait is replaced by a fresh dict of them.
        """
        data = getattr(self.dataset, '%s_data' % d_type)
        for attr in ('scalars', 'vectors', 'tensors'):
            arrays = {}
            for name in attributes[attr]:
                va = data.get_array(name)
                npa = va.to_array()
                # An empty array has no element to probe, and nothing to
                # keep in sync, so only non-empty arrays are checked.
                if npa.size > 0:
                    self._ensure_shared(va, npa)
                arrays[name] = npa

            setattr(self, '%s_%s' % (d_type, attr), arrays)

    @staticmethod
    def _ensure_shared(va, npa):
        """Make in-place changes of numpy array `npa` visible in `va`.

        The first element of `npa` is temporarily decremented and compared
        with the VTK array; if the change is not reflected there, `va` is
        pointed at `npa` via ``from_array``.  The original value is
        restored afterwards.  `npa` must be non-empty.
        """
        if len(npa.shape) > 1:
            old = npa[0, 0]
            npa[0][0] = old - 1
            if abs(va[0][0] - npa[0, 0]) > 1e-8:
                va.from_array(npa)
            npa[0][0] = old
        else:
            old = npa[0]
            npa[0] = old - 1
            if abs(va[0] - npa[0]) > 1e-8:
                va.from_array(npa)
            npa[0] = old

    def _activate_data_array(self, data_type, category, name):
        """Activate (or deactivate) a particular array.

        Given the nature of the data (scalars, vectors etc.) and the
        type of data (cell or points) it activates the array given by
        its name.

        Parameters:
        -----------

        data_type: one of 'scalars', 'vectors', 'tensors'
        category: one of 'cell', 'point'.
        name: string of array name to activate; an empty string (or None)
              deactivates that attribute instead.
        """
        data = getattr(self.dataset, category + '_data')
        set_active = getattr(data, 'set_active_%s' % data_type)
        if not name:
            # If the value is empty then we deactivate that attribute.
            set_active(None)
        else:
            set_active(name)
            aa = self._assign_attribute
            aa.assign(name, data_type.upper(), category.upper() + '_DATA')
            aa.update()

    def _find_array(self, name, category='point'):
        """Return which kind of attribute ('scalars', 'vectors' or
        'tensors') contains the specified named array in a particular
        category.  Raises KeyError if none does."""
        for kind in ('scalars', 'vectors', 'tensors'):
            if name in getattr(self, '%s_%s' % (category, kind)):
                return kind
        raise KeyError('No %s array named %s available in dataset' %
                       (category, name))
Пример #25
0
class ExperimentDialogModel(HasStrictTraits):
    """
    The model for the Experiment setup dialog.

    Holds the tubes (table rows) and the user-defined tube traits (table
    columns), keeps each tube's dynamic traits and ``conditions`` dict in
    sync with the columns, and counts how many tubes share each set of
    conditions so duplicates can be flagged.
    """

    # the list of Tubes (rows in the table)
    tubes = List(Tube)

    # a list of the traits that have been added to Tube instances
    # (columns in the table)
    tube_traits = List(TubeTrait)
    tube_traits_dict = Dict

    # keeps track of whether a tube is unique or not: maps a tube's
    # conditions hash to the number of tubes that share it
    counter = Dict(Int, Int)

    # are all the tubes unique and filled?
    valid = Property(List)

    # a dummy Experiment, with the first Tube and no events, so we can check
    # subsequent tubes for voltage etc. and fail early.
    dummy_experiment = Instance(Experiment)

    # traits to communicate with the traits_view
    fcs_metadata = Property(List, depends_on='tubes')

    def init(self, import_op):
        """Populate the model from an existing import operation.

        Adds a 'CF_File' metadata column (unless the operation already
        has that condition) plus one column per condition, loads the FCS
        metadata of the operation's tubes into ``dummy_experiment`` and
        recreates the tubes.  If loading fails, the user is warned and the
        tubes are not restored.
        """
        if 'CF_File' not in import_op.conditions:
            self.tube_traits.append(
                TubeTrait(model=self, type='metadata', name='CF_File'))

        for name, condition in import_op.conditions.items():
            if condition in ("category", "object"):
                self.tube_traits.append(
                    TubeTrait(model=self, name=name, type='category'))
            elif condition in ("int", "float"):
                self.tube_traits.append(
                    TubeTrait(model=self, name=name, type='float'))
            elif condition == "bool":
                self.tube_traits.append(
                    TubeTrait(model=self, name=name, type='bool'))

        self.dummy_experiment = None

        if import_op.tubes:
            try:
                self.dummy_experiment = import_op.apply(metadata_only=True,
                                                        force=True)
            except Exception as e:
                warning(
                    None, "Had trouble loading some of the experiment's FCS "
                    "files.  You will need to re-add them.\n\n{}".format(
                        str(e)))
                return

            for op_tube in import_op.tubes:
                metadata = self.dummy_experiment.metadata['fcs_metadata'][
                    op_tube.file]
                tube = Tube(file=op_tube.file,
                            parent=self,
                            metadata=sanitize_metadata(metadata))

                self.tubes.append(tube)  # adds the dynamic traits to the tube

                tube.trait_set(**op_tube.conditions)

                for trait in self.tube_traits:
                    if trait.type == 'metadata':
                        tube.trait_set(
                            **{trait.name: tube.metadata[trait.name]})
                    else:
                        tube.conditions[trait.name] = tube.trait_get()[
                            trait.name]

    def _update_counter(self):
        """Recount how many tubes share each conditions hash."""
        self.counter.clear()
        for tube in self.tubes:
            tube_hash = tube.conditions_hash()
            self.counter[tube_hash] = self.counter.get(tube_hash, 0) + 1

    @on_trait_change('tubes_items')
    def _tubes_items(self, event):
        """Add every existing column's dynamic trait to newly added tubes."""
        for tube in event.added:
            for trait in self.tube_traits:
                if not trait.name:
                    continue

                tube.add_trait(trait.name, trait.trait)

                if trait.type == 'metadata':
                    tube.trait_set(**{trait.name: tube.metadata[trait.name]})
                else:
                    tube.trait_set(**{trait.name: trait.trait.default_value})
                    tube.conditions[trait.name] = tube.trait_get()[trait.name]

        self._update_counter()

    @on_trait_change('tube_traits_items')
    def _tube_traits_changed(self, event):
        """Add/remove the dynamic trait on every tube when columns change."""
        for trait in event.added:
            if not trait.name:
                continue

            for tube in self.tubes:
                tube.add_trait(trait.name, trait.trait)

                if trait.type == 'metadata':
                    tube.trait_set(**{trait.name: tube.metadata[trait.name]})
                else:
                    tube.trait_set(**{trait.name: trait.trait.default_value})
                    tube.conditions[trait.name] = tube.trait_get()[trait.name]

            self.tube_traits_dict[trait.name] = trait

        for trait in event.removed:
            if not trait.name:
                continue

            for tube in self.tubes:
                tube.remove_trait(trait.name)

                if trait.type != 'metadata':
                    del tube.conditions[trait.name]

            del self.tube_traits_dict[trait.name]

        self._update_counter()

    @on_trait_change('tube_traits:name')
    def _on_trait_name_change(self, trait, _, old_name, new_name):
        """Move a column's values from `old_name` to `new_name` on each tube.

        If `new_name` is a metadata key its metadata value is used;
        otherwise the old value is kept, or the trait's default when there
        was no old name.
        """
        for tube in self.tubes:
            trait_value = None

            if old_name:
                # store the value
                trait_value = tube.trait_get()[old_name]

                if trait.type != 'metadata':
                    del tube.conditions[old_name]

                # defer removing the old trait until the handler
                # tube.remove_trait(old_name)

            if new_name:
                if new_name in tube.metadata:
                    trait_value = tube.metadata[new_name]
                elif trait_value is None:
                    trait_value = trait.trait.default_value

                tube.add_trait(new_name, trait.trait)
                tube.trait_set(**{new_name: trait_value})

                if trait.type != 'metadata':
                    tube.conditions[new_name] = trait_value

        if old_name:
            del self.tube_traits_dict[old_name]

        if new_name:
            self.tube_traits_dict[new_name] = trait

        self._update_counter()

    @on_trait_change('tube_traits:type')
    def _on_type_change(self, trait, name, old_type, new_type):
        """Re-create a column's trait on each tube after its type changes.

        The previous value is kept when the new trait accepts it, otherwise
        the new trait's default value is used.
        """
        if not trait.name:
            return

        for tube in self.tubes:
            trait_value = tube.trait_get()[trait.name]
            tube.remove_trait(trait.name)
            if old_type != 'metadata':
                del tube.conditions[trait.name]

            tube.add_trait(trait.name, trait.trait)

            try:
                tube.trait_set(**{trait.name: trait_value})
            except TraitError:
                tube.trait_set(**{trait.name: trait.trait.default_value})
            if new_type != 'metadata':
                tube.conditions[trait.name] = tube.trait_get()[trait.name]

        self._update_counter()

    def update_import_op(self, import_op):
        """Write the model's tubes and conditions back into `import_op`.

        Also sets ``ret_events`` (sum of each tube's 'TOT' metadata) and
        ``original_channels``.  If the operation's channel list is empty or
        names a channel that these files lack, the channels are reset (with
        a warning in the latter case).
        """
        if not self.tubes:
            return

        assert self.dummy_experiment is not None

        conditions = {
            trait.name: trait.type
            for trait in self.tube_traits if trait.type != 'metadata'
        }

        tubes = []
        events = 0
        for tube in self.tubes:
            op_tube = CytoflowTube(file=tube.file,
                                   conditions=tube.trait_get(
                                       list(conditions)))
            tubes.append(op_tube)
            events += tube.metadata['TOT']

        import_op.ret_events = events

        import_op.conditions = conditions
        import_op.tubes = tubes
        import_op.original_channels = channels = self.dummy_experiment.channels

        channels_list = import_op.channels_list
        all_present = (len(channels_list) > 0
                       and all(c.name in channels for c in channels_list))

        if len(channels_list) > 0 and not all_present:
            warning(
                None, "Some of the operation's channels weren't found in "
                "these FCS files.  Resetting all channel names.",
                "Resetting channel names")

        if not all_present:
            import_op.reset_channels()

    def is_tube_unique(self, tube):
        """Return True if no other tube shares `tube`'s conditions."""
        return self.counter.get(tube.conditions_hash()) == 1

    def _get_valid(self):
        # every tube has a distinct conditions hash and all conditions set
        return (len(self.tubes) > 0
                and len(self.counter) == len(self.tubes)
                and all(tube.all_conditions_set for tube in self.tubes))

    # magic: gets the list of FCS metadata for the trait list editor
    @cached_property
    def _get_fcs_metadata(self):
        # names of metadata keys whose value differs between tubes, sorted
        meta = {}
        for tube in self.tubes:
            for name, val in tube.metadata.items():
                meta.setdefault(name, set()).add(val)

        return sorted(name for name, vals in meta.items() if len(vals) > 1)
Пример #26
0
class ProjectInfo(HasTraits):
    """Project-level settings: directory layout, file-name templates and
    acquisition parameters, each documented through its ``desc``."""

    data_dir = Str(
        "../data",
        desc=dedent("""
        A relative path to the directory where raw data is stored.
        """),
    )
    proc_dir = Str(
        "../proc",
        desc=dedent("""
        A relative path to the directory where lyman workflows will output
        persistent data.
        """),
    )
    cache_dir = Str(
        "../cache",
        desc=dedent("""
        A relative path to the directory where lyman workflows will write
        intermediate files during execution.
        """),
    )
    remove_cache = Bool(
        True,
        desc=dedent("""
        If True, delete the cache directory containing intermediate files after
        successful execution of the workflow. This behavior can be overridden
        at runtime by command-line arguments.
        """),
    )
    fm_template = Str(
        "{session}_fieldmap_{encoding}.nii.gz",
        desc=dedent("""
        A template string to identify session-specific fieldmap files.
        """),
    )
    ts_template = Str(
        "{session}_{experiment}_{run}.nii.gz",
        desc=dedent("""
        A template string to identify time series data files.
        """),
    )
    sb_template = Str(
        "{session}_{experiment}_{run}_ref.nii.gz",
        desc=dedent("""
        A template string to identify reference volumes corresponding to each
        run of time series data.
        """),
    )
    voxel_size = Tuple(
        Float(2), Float(2), Float(2),
        desc=dedent("""
        The voxel size to use for the functional template.
        """),
    )
    phase_encoding = Enum(
        "pa", "ap",
        desc=dedent("""
        The phase encoding direction used in the functional acquisition.
        """),
    )
    scan_info = Dict(
        Str, Dict(Str, Dict(Str, List(Str))),
        desc=dedent("""
        Information about scanning sessions, populated by reading the
        ``scan_info.yaml`` file.
        """),
    )
Пример #27
0
class DemoPath(DemoTreeNodeObject):
    """ This class represents a directory.

    Children come either from a configuration dictionary/file (when one is
    set) or from the file system: sub-directories that contain ``.py``
    files, plus files whose extension has an entry in ``_file_factory``.
    """

    # -------------------------------------------------------------------------
    #  Trait definitions:
    # -------------------------------------------------------------------------

    #: Parent of this package:
    parent = Any()

    #: Name of file system path to this package:
    path = Property(depends_on='parent.path,name')

    #: Description of what the demo does:
    description = Property(HTML, depends_on="path,css_filename")

    #: The base URL for links:
    base_url = Property(depends_on='path')

    #: The css file for this node.
    css_filename = Str("default.css")

    #: Name of the directory:
    name = Str()

    #: UI form of the 'name':
    nice_name = Property()

    #: Dictionary containing symbols defined by the path's '__init__.py' file:
    init_dic = Property()

    #: Should .py files be included?
    use_files = Bool(True)

    #: Paths do allow children:
    allows_children = Bool(True)

    #: Configuration dictionary for this node
    #: This trait is set when a config file exists for the parent of this path.
    config_dict = Dict()

    #: Configuration file for this node.
    config_filename = Str()

    #: Shadow trait for description property
    _description = Str()

    #: Cached value of the nice_name property.
    _nice_name = Str()

    #: Dictionary mapping file extensions to callables
    _file_factory = Dict()

    def __file_factory_default(self):
        return {
            ".htm": DemoContentFile,
            ".html": DemoContentFile,
            ".jpeg": DemoImageFile,
            ".jpg": DemoImageFile,
            ".png": DemoImageFile,
            ".py": DemoFile,
            ".rst": DemoContentFile,
            ".txt": DemoContentFile
        }

    # -------------------------------------------------------------------------
    #  Implementation of the 'path' property:
    # -------------------------------------------------------------------------

    def _get_path(self):
        if self.parent is not None:
            path = join(self.parent.path, self.name)
        else:
            path = self.name

        return path

    def _get_base_url(self):
        # The path itself when it is a directory, otherwise its directory.
        if isdir(self.path):
            base_dir = self.path
        else:
            base_dir = dirname(self.path)
        return base_dir

    # -------------------------------------------------------------------------
    #  Implementation of the 'nice_name' property:
    # -------------------------------------------------------------------------

    def _get_nice_name(self):
        if not self._nice_name:
            self._nice_name = user_name_for(self.name)
        return self._nice_name

    # -------------------------------------------------------------------------
    #  Setter for the 'nice_name' property:
    # -------------------------------------------------------------------------

    def _set_nice_name(self, value):
        old = self.nice_name
        self._nice_name = value
        self.trait_property_changed("nice_name", old, value)

    # -------------------------------------------------------------------------
    #  Implementation of the 'description' property:
    # -------------------------------------------------------------------------

    @cached_property
    def _get_description(self):
        # Render the directory's description file (empty if absent) as HTML.
        index_rst = os.path.join(self.path, DESCRIPTION_RST_FILENAME)
        if os.path.exists(index_rst):
            with open(index_rst, "r", encoding="utf-8") as f:
                description = f.read()
        else:
            description = ""

        if self.css_filename:
            result = publish_html_str(description, self.css_filename)
        else:
            result = publish_html_str(description)
        return result

    # -------------------------------------------------------------------------
    #  Implementation of the 'init_dic' property:
    # -------------------------------------------------------------------------

    def _get_init_dic(self):
        # NOTE: this executes the directory's __init__.py; demo directories
        # must therefore be trusted.
        init_dic = {}
        description, source = parse_source(join(self.path, "__init__.py"))
        exec((exec_str + source), init_dic)
        return init_dic

    # -------------------------------------------------------------------------
    #  Returns whether or not the object has children:
    # -------------------------------------------------------------------------

    def has_children(self):
        """ Returns whether or not the object has children.

        Any sub-directory counts as a child.  When ``use_files`` is set, so
        does any file other than ``__init__.py`` whose extension has a
        factory in ``_file_factory`` (the same files that
        ``get_children_from_datastructure`` turns into nodes).
        """
        path = self.path
        for name in listdir(path):
            if isdir(join(path, name)):
                return True

            # __init__.py is never shown as a child, so it must not count
            # here either (".py" is in _file_factory).
            if self.use_files and name != "__init__.py":
                if splitext(name)[1] in self._file_factory:
                    return True

        return False

    # -------------------------------------------------------------------------
    #  Gets the object's children:
    # -------------------------------------------------------------------------

    def get_children(self):
        """ Gets the object's children.
        """
        if self.config_dict or self.config_filename:
            children = self.get_children_from_config()
        else:
            children = self.get_children_from_datastructure()
        return children

    # -------------------------------------------------------------------------
    #  Gets the object's children based on the filesystem structure.
    # -------------------------------------------------------------------------
    def get_children_from_datastructure(self):
        """ Gets the object's children based on the filesystem structure.

        Directories (containing ``.py`` files somewhere below them) come
        first, then files, each group sorted by name.
        """

        dirs = []
        files = []
        path = self.path
        for name in listdir(path):
            cur_path = join(path, name)
            if isdir(cur_path):
                if self.has_py_files(cur_path):
                    dirs.append(
                        DemoPath(
                            parent=self,
                            name=name,
                            css_filename=join('..', self.css_filename)
                        )
                    )
            elif self.use_files:
                if name != "__init__.py":
                    try:
                        demo_file = self._handle_file(name)
                    except KeyError:
                        # No factory for this extension: not a demo file.
                        continue
                    demo_file.css_filename = self.css_filename
                    files.append(demo_file)

        sort_key = operator.attrgetter("name")
        dirs.sort(key=sort_key)
        files.sort(key=sort_key)

        return dirs + files

    # -------------------------------------------------------------------------
    # Gets the object's children as specified in its configuration file or
    # dictionary.
    # -------------------------------------------------------------------------

    def get_children_from_config(self):
        """
        Gets the object's children as specified in its configuration file or
        dictionary.

        Falls back to ``get_children_from_datastructure`` when no usable
        configuration is available.  The configuration is only read, never
        modified, so repeated calls return the same children.
        """

        if not self.config_dict:
            if exists(self.config_filename):
                try:
                    self.config_dict = ConfigObj(self.config_filename)
                except Exception:
                    # Best effort: an unreadable config file falls back to
                    # the file-system layout below.
                    pass
        if not self.config_dict:
            return self.get_children_from_datastructure()

        dirs = []
        files = []
        for keyword, value in self.config_dict.items():
            if value.get("no_demo"):
                continue

            # Read (rather than pop) keys: the section objects are shared
            # across calls, so mutating them would break later calls.
            sourcedir = value.get("sourcedir", None)
            if sourcedir is not None:
                # This is a demo directory.
                demoobj = DemoPath(
                    parent=self,
                    name=sourcedir,
                    css_filename=join("..", self.css_filename),
                )
                demoobj.nice_name = keyword
                demoobj.config_dict = {
                    key: val for key, val in value.items()
                    if key != "sourcedir"
                }
                dirs.append(demoobj)
                continue

            names = []
            filenames = value.get("files", [])
            if not isinstance(filenames, list):
                filenames = [filenames]
            for filename in filenames:
                pattern = join(self.path, filename)
                for name in glob.iglob(pattern):
                    if basename(name) != "__init__.py":
                        names.append(name)

            if len(names) > 1:
                # Several matches: group them under a synthetic directory.
                config_dict = {}
                for name in names:
                    config_dict[basename(name)] = {"files": name}
                demoobj = DemoPath(parent=self, name="")
                demoobj.nice_name = keyword
                demoobj.config_dict = config_dict
                demoobj.css_filename = os.path.join(
                    "..", self.css_filename)
                dirs.append(demoobj)
            elif len(names) == 1:
                try:
                    demo_file = self._handle_file(names[0])
                except KeyError:
                    # No factory for this extension: not a demo file.
                    continue
                demo_file.css_filename = self.css_filename
                files.append(demo_file)

        sort_key = operator.attrgetter("nice_name")
        dirs.sort(key=sort_key)
        files.sort(key=sort_key)

        return dirs + files

    # -------------------------------------------------------------------------
    #  Returns whether the specified path contains any .py files:
    # -------------------------------------------------------------------------

    def has_py_files(self, path):
        """ Return True if `path` or any sub-directory has a ``.py`` file.
        """
        for name in listdir(path):
            cur_path = join(path, name)
            if isdir(cur_path):
                if self.has_py_files(cur_path):
                    return True
            elif splitext(name)[1] == ".py":
                return True

        return False

    def _handle_file(self, filename):
        """ Process a file based on its extension.

        Raises KeyError when no factory is registered for the extension.
        """
        _, ext = splitext(filename)
        file_factory = self._file_factory[ext]
        demo_file = file_factory(parent=self, name=filename)
        return demo_file
Пример #28
0
class RunStorage(HasTraits):
    ''' Holds the 2-D data built from an IRun's file_data_info dictionary and
        maintains the ArrayPlotData objects for the full spectral and temporal data.

        Event hierarchy is as follows:

        - file_data_info is updated from the "update_storage" method of the IRun instance
        - _file_data_info_changed rebuilds t_label, x_label and twoD_data_full (in that
          order), then the DataFrame, the plot data sources and the plots
        - update_plotdata creates the specdata (one entry per t_label) and timedata
          (one entry per x_label) plot data sources from twoD_data_full
        - changes to the averaging variables (x_avg, t_avg, averaging_style) write
          twoD_data_avg, while twoD_data_full is retained intact unless
          file_data_info changes

        Label overrides (override_labels) replace x_label / t_label in this class
        and are then pushed down to the plots.
    '''
    implements(IRunStorage)

    file_data_info = Dict(Str, Tuple(Array(dtype=spec_dtype), Array))

    ### Labels are intended to be strings because chaco requires these straight.  The t_label and x_label are used to both
    ### store the full names of the data (eg trial1, trial2) and also used as keys in the array plot data object.  In
    ### addition, chaco plot objects need "xarray" and "tarray" objects that correspond to the t_labels/x_labels.  This is because
    ### chaco can't plot a list of strings, so intermediate arrays are created which are the same length as the labels
    ### but are simply evenly spaced index arrays.  THEREFORE the shape, size and events that
    ### have to do with labels control all the redraw events and everything.  The label change events must therefore be considered
    ### very carefully.  The plot objects also rely on these labels exclusively to do sampling and get correct sizes.

    test_override = Button

    x_label = Any
    t_label = Any
    _x_size = Int  # Full size of the arrays (used for averaging)
    _t_size = Int  # Set whenever the labels are updated

    twoD_data_full = Array  # 2-D data: rows follow x_label, columns follow t_label
    twoD_data_avg = Array  # 2-D data after averaging; kept alongside the full data for comparison
    _valid_x_averages = List  # Set automatically when the 2-D data is created
    _valid_t_averages = List

    specdata = Instance(ArrayPlotData, ())  # Replaced by update_plotdata
    timedata = Instance(ArrayPlotData, ())

    dframe = Instance(DataFrame)

    plots = Instance(IPlot)

    ### Global sampling traits (not yet synched up in v3) ###
    x_spacing = Int(1)  # Do I want to do sampling or just averaging or both?
    t_spacing = Int(1)

    ### Global averaging filter traits ###
    t_avg = Enum(values='_valid_t_averages')
    x_avg = Enum(values='_valid_x_averages')
    averaging_style = Enum('Reshaping', 'Rolling')  # Rolling not implemented yet

    def _test_override_fired(self):
        '''Debug button: double every x label and suffix every t label with 'lol'.'''
        newx = [str(2 * float(entry)) for entry in self.x_label]
        newt = [str(entry) + 'lol' for entry in self.t_label]
        # NOTE(review): 'xlDabel' is not a valid key, so override_labels reports and
        # drops it -- either a typo or a deliberate check of the invalid-key path.
        self.override_labels(xlDabel=newx, x_label=newx, t_label=newt)

    def make_dframe(self):
        '''Build a DataFrame of twoD_data_full indexed by x_label with t_label columns.'''
        # The former debug dump (PandasPlotData(...).list_data printed on every
        # load) and an unused transpose were removed.
        self.dframe = DataFrame(self.twoD_data_full, list(self.x_label),
                                list(self.t_label))
        ### MAY FIND THAT LOOKUP IS BEST METHOD TO USE FOR SAMPLING OPERATIONS

    def _file_data_info_changed(self):
        ''' Rebuild everything derived from file_data_info, in order.

            This used to be a set of properties, but because twoD data listened to t_label and
            x_label separately, it would cause a double update which would try to draw incomplete
            data.  Hence the steps run sequentially.  The user can pass new labels via keyword
            using the override_labels function.'''
        self.update_t_label()
        self.update_x_label()
        self.update_full_data()
        self.make_dframe()
        self.update_data_and_plots()

    def update_data_and_plots(self):
        '''Rebuild the plot data sources, then push them to the connected plots.

           Used when the underlying matrix changes due to a new file, new averaging etc.'''
        # NOTE(review): update_plotdata reads twoD_data_full, not twoD_data_avg, so the
        # averaging handlers do not change what is plotted -- confirm intent.
        self.update_plotdata()  ## FIX LISTENERS
        self.update_plots()

    def _averaging_style_changed(self):
        ''' Re-run both averaging passes, rows first, then columns.

            NOTE(review): each pass recomputes twoD_data_avg from twoD_data_full, so the
            column pass overwrites the row result instead of compounding it.'''
        self._x_avg_changed()
        self._t_avg_changed()

    ### Separated averaging so that user can control the order of averaging, and so that array wasn't being
    ### reshaped in both rows and columns if only one dimension was changing
    def _x_avg_changed(self):
        '''Average twoD_data_full over blocks of x_avg consecutive rows into twoD_data_avg.'''
        if self.averaging_style == 'Reshaping':
            print('reshaping x')
            # Reshaping requires a factor of the row count: 6 rows can become 3 rows
            # (averaging 2 at a time) or 2 rows (averaging 3 at a time).
            self._valid_x_averages = _get_factors(len(self.x_label))
        elif self.averaging_style == 'Rolling':
            print('Need to build the rolling average method')

        # (# rows to remain, rows per block, columns): 500 rows averaged by 100
        # -> reshape([5, 100, columns]).mean(1).  Floor division keeps the shape
        # integral on Python 3 as well.
        avgarray = self.twoD_data_full.reshape(
            [self._x_size // self.x_avg, self.x_avg, self._t_size]).mean(1)
        self.twoD_data_avg = avgarray
        self.update_data_and_plots()

    def _t_avg_changed(self):
        '''Average twoD_data_full over blocks of t_avg consecutive columns into twoD_data_avg.'''
        if self.averaging_style == 'Reshaping':
            print('reshaping t')
            self._valid_t_averages = _get_factors(len(self.t_label))
        elif self.averaging_style == 'Rolling':
            print('Need to build the rolling average method')

        # (rows, # columns to remain, columns per block): 500 columns averaged by 100
        # -> reshape([rows, 5, 100]).mean(2), giving (rows, 5).  The previous
        # reshape([t_avg, rows, columns / t_avg]).mean(2).transpose() mixed values
        # from different rows and produced t_avg columns instead of columns / t_avg.
        avgarray = self.twoD_data_full.reshape(
            [self._x_size, self._t_size // self.t_avg, self.t_avg]).mean(2)
        self.twoD_data_avg = avgarray
        self.update_data_and_plots()

    def update_plotdata(self):
        ''' Overwrite the primary data sources (specdata, timedata) from twoD_data_full.

            Plots are programmed to redraw when these are overwritten.  This should only
            occur when a global variable is changed; all local variable changes are built
            into the plot objects already.'''

        #### ALL LISTENERS ARE HOOKED UP, THIS FUNCTION IS PROBABLY CAUSING THE ISSUE... MAYBE
        #### OVERWRITING THE DATA ARRAYS IS CAUSING THIS

        print('Updating all ctprimary data sources')

        specdata = ArrayPlotData()
        timedata = ArrayPlotData()

        # NOTE(review): linspace(0, n, n) spaces n points from 0 to n (not 1, 2, ..., n);
        # confirm which index spacing the plots expect.
        xarray = linspace(0, len(self.x_label), len(self.x_label))
        tarray = linspace(0, len(self.t_label), len(self.t_label))

        specdata.set_data('x', xarray)
        timedata.set_data('x', tarray)

        # One spectrum per t label (a column) and one time trace per x label (a row);
        # the label values themselves are the keys.
        for label, column in zip(self.t_label, self.twoD_data_full.T):
            specdata.set_data(label, column)

        for label, row in zip(self.x_label, self.twoD_data_full):
            timedata.set_data(label, row)

        self.specdata = specdata
        self.timedata = timedata

    def update_plots(self):
        ''' Create the SpecPlot on first use, then hand it specdata keyed by t_label.

            Make list eventually and sync iterably.'''
        print('updating plots from specstorage object')

        if self.plots is None:
            print('making new plots in spec storage object')
            # AreaPlot / AbsPlot were alternatives here; syncing t_spacing to the
            # plot's spacing via sync_trait did not work yet.
            self.plots = SpecPlot(plothandler=self)

        self.plots.set_major_data(maindata=self.specdata,
                                  mainlabel=self.t_label)

    ### Label builders based on the file_data_info dictionary ###

    def update_t_label(self):
        '''Store the file names sorted by name; used to parse the data in a sorted manner and for axis labels.'''
        self.t_label = sorted(self.file_data_info)
        self._t_size = len(self.t_label)

    def update_x_label(self):
        '''Use the wavelengths of the first (sorted) file as x labels.

        Assumes every file shares the same wavelength axis -- TODO confirm.'''
        firstfile = self.t_label[0]
        self.x_label = self.get_wavelengths(firstfile)
        self._x_size = len(self.x_label)

    def override_labels(self, **kwargs):
        ''' Replace x_label and/or t_label with user-supplied labels and refresh the plots.

            Only 'x_label' and 't_label' are accepted; any other keyword is reported and
            ignored rather than raising the usual error for an invalid keyword.  Each new
            label sequence must have the same length as the current one (mismatches are
            reported and skipped) and is converted entry-by-entry to str, so e.g. an array
            of floats may be passed.  New traits can be supported by adding them to "valid".'''

        valid = {'x_label': self.x_label, 't_label': self.t_label}

        ##Report and drop wrong keywords##
        invalid = [key for key in kwargs if key not in valid]
        for key in invalid:
            print('\n\n You entered key\t%s\tbut key must be one of the following:\t%s\n\n'
                  % (key, '\t'.join(valid)))
            kwargs.pop(key)

        ##Only accept labels with the same length as the current ones##
        updated = []
        for key, new_label in kwargs.items():
            if len(new_label) != len(valid[key]):
                print('\n\n You tried to update\t%s\tof length\t%s\tbut the current value has length\t%s'
                      % (key, len(new_label), len(valid[key])))
                continue
            # Previously only the local "valid" dict was rebound, so the trait itself
            # never changed; assign the trait so the new labels actually take effect.
            setattr(self, key, [str(entry) for entry in new_label])
            updated.append(key)

        print('updated labels for the following entries:\t%s\n\n' % '\t'.join(updated))

        if updated:
            # Labels are the DataFrame index/columns and the plot-data keys, so every
            # derived structure and the plots must be rebuilt.
            self.make_dframe()
            self.update_data_and_plots()

    def update_full_data(self):
        '''Store 2-d data (rows follow x_label, columns follow t_label) for easy input into
        a multiplot, refresh the valid averaging factors and the row-averaged twoD_data_avg.'''
        fullarray = empty((len(self.x_label), len(self.t_label)), dtype=float)
        for index, afile in enumerate(self.t_label):  # t_label is pre-sorted
            fullarray[:, index] = self.get_intensities(afile)
        self.twoD_data_full = fullarray
        if self.averaging_style == 'Reshaping':
            # Reshaping requires factors of the axis length: 6 rows can become 3 rows
            # (averaging 2 at a time) or 2 rows (averaging 3 at a time).
            self._valid_x_averages = _get_factors(len(self.x_label))
            self._valid_t_averages = _get_factors(len(self.t_label))
        elif self.averaging_style == 'Rolling':
            print('Need to build the rolling average method')

        # (# rows to remain, rows per block, columns): 500 rows averaged by 100
        # -> reshape([5, 100, columns]).mean(1)
        avgarray = self.twoD_data_full.reshape(
            [self._x_size // self.x_avg, self.x_avg, self._t_size]).mean(1)
        print(avgarray.shape)
        self.twoD_data_avg = avgarray

    ### Simple accessors into file_data_info
    def get_wavelengths(self, afile):
        return self.file_data_info[afile][0]['wavelength']

    def get_intensities(self, afile):
        return self.file_data_info[afile][0]['intensity']

    traits_view = View(Item('t_spacing', label='Global t sampling'),
                       Item('plots', style='custom', show_label=False),
                       Item('test_override'),
                       Item('averaging_style'),
                       Item('x_avg'),
                       Item('t_avg'),
                       resizable=True)
Пример #29
0
class TrackingView(HasTraits):
  """Show per-channel tracking SNRs as a live bar chart and a line-plot history.

  SNR messages (MSG_TRACKING_SNRS) received on ``link`` are unpacked into
  ``snrs``; each change schedules ``update_snrs`` through ``GUI.invoke_later``,
  which refreshes the bar plot, the running average and the history plot.
  """
  python_console_cmds = Dict()
  snrs = Array(dtype=float, shape=(TRACK_N_CHANNELS,))
  snrs_avg = Array(dtype=float, shape=(TRACK_N_CHANNELS,))
  snrs_history = List()
  snr_bars = Instance(Component)

  plot = Instance(Plot)
  plot_data = Instance(ArrayPlotData)

  snr_bar_view = View(
    Item('snr_bars', editor=ComponentEditor(size=(100,100)), show_label=False)
  )

  snr_line_view = View(
    Item(
      'plot',
      editor = ComponentEditor(bgcolor = (0.8,0.8,0.8)),
      show_label = False,
    )
  )

  traits_view = View(
    Item('snrs'),
    Item('snrs_avg')
  )

  def tracking_snrs_callback(self, data):
    """Unpack one little-endian float per channel from ``data`` into ``snrs``."""
    fmt = '<' + str(TRACK_N_CHANNELS) + 'f'
    self.snrs = struct.unpack(fmt, data)

  def _snrs_changed(self):
    # Defer the plot refresh to the GUI event loop.
    GUI.invoke_later(self.update_snrs)

  def update_snrs(self):
    """Record the latest SNRs and refresh the average, bar plot and line plot.

    ``snrs_history`` keeps at most the 500 most recent samples (all the line
    plot ever shows); ``snrs_avg`` is the per-channel mean of the last 100.
    """
    self.snrs_history.append(self.snrs)
    # Only the last 500 samples are plotted, so drop older ones instead of
    # letting the history grow without bound.
    if len(self.snrs_history) > 500:
      del self.snrs_history[:-500]
    s = np.array(self.snrs_history[-100:])
    # Mean over the samples actually present; dividing by a fixed 100
    # under-reported the average until 100 samples had arrived.
    self.snrs_avg = np.mean(s, axis=0)
    # Filter out channels that are not tracking.
    #self.snrs_avg = map(lambda n, x: x if self.snrs[n] != -1 else -1, self.snrs_avg)
    self.vals.set_data(self.snrs)

    # One row per channel, one column per stored sample.
    chans = np.transpose(self.snrs_history[-500:])
    t = np.arange(len(chans[0]))
    self.plot_data.set_data('t', t)
    for n in range(TRACK_N_CHANNELS):
      self.plot_data.set_data('ch'+str(n), chans[n])

  def __init__(self, link):
    super(TrackingView, self).__init__()

    self.link = link
    self.link.add_callback(MSG_TRACKING_SNRS, self.tracking_snrs_callback)

    # ======= Line Plot =======

    self.plot_data = ArrayPlotData(t=[0.0])
    self.plot = Plot(self.plot_data, auto_colors=colours_list)
    self.plot.value_range.tight_bounds = False
    self.plot.value_range.low_setting = 0.0
    for n in range(TRACK_N_CHANNELS):
      self.plot_data.set_data('ch'+str(n), [0.0])
      self.plot.plot(('t', 'ch'+str(n)), type='line', color='auto')

    # ======= Bar Plot =======

    # Bars are indexed 1..N (materialized as an array rather than a lazy range).
    idxs = ArrayDataSource(np.arange(1, len(self.snrs) + 1))
    self.vals = ArrayDataSource(self.snrs, sort_order='none')
    # Create the index range
    index_range = DataRange1D(idxs, low=0.4, high=TRACK_N_CHANNELS+0.6)
    index_mapper = LinearMapper(range=index_range)
    # Create the value range
    value_range = DataRange1D(low=0.0, high=25.0)
    value_mapper = LinearMapper(range=value_range)

    plot = BarPlot(index=idxs, value=self.vals,
                   index_mapper=index_mapper, value_mapper=value_mapper,
                   line_color='blue', fill_color='blue', bar_width=0.8)

    container = OverlayPlotContainer(bgcolor = "white")
    plot.padding = 10
    plot.padding_left = 30
    plot.padding_bottom = 30
    container.add(plot)

    left_axis = PlotAxis(plot, orientation='left')
    bottom_axis = LabelAxis(plot, orientation='bottom',
                           labels = [str(i) for i in range(1, TRACK_N_CHANNELS+1)],
                           positions = list(range(1, TRACK_N_CHANNELS+1)),
                           small_haxis_style=True)

    plot.underlays.append(left_axis)
    plot.underlays.append(bottom_axis)

    self.snr_bars = container

    self.python_console_cmds = {
      'track': self
    }
Пример #30
0
class AuxPlot(HasTraits):
    names = List
    _plot_names = List

    save_enabled = Bool
    plot_enabled = Bool
    name = Str(NULL_STR)
    plot_name = Property(Str, depends_on='name')
    scale = Enum('linear', 'log')
    height = Int(100, enter_set=True, auto_set=False)
    x_error = Bool(False)
    y_error = Bool(False)
    ytitle_visible = Bool(True)
    ytick_visible = Bool(True)
    show_labels = Bool(False)

    filter_str = FilterPredicate
    sigma_filter_n = Int
    has_filter = Property(depends_on='filter_str, sigma_filter_n')
    filter_str_tag = Enum(('Omit', 'Outlier', 'Invalid'))
    sigma_filter_tag = Enum(('Omit', 'Outlier', 'Invalid'))

    normalize = None
    use_time_axis = False
    initialized = False

    ymin = Float
    ymax = Float

    ylimits = Tuple(Float, Float, transient=True)
    xlimits = Tuple(Float, Float, transient=True)

    overlay_positions = Dict(transient=True)
    _has_ylimits = Bool(False, transient=True)
    _has_xlimits = Bool(False, transient=True)

    # enabled = True

    marker = Str('circle')
    marker_size = Float(2)

    _suppress = False

    calculated_ymax = Any(transient=True)
    calculated_ymin = Any(transient=True)

    def to_dict(self):
        """Return a dict mapping every non-transient trait name to its current value."""
        names = self.traits(transient=False)
        return dict((name, getattr(self, name)) for name in names)

    # @on_trait_change('ylimits')
    # def _handle_ylimits(self, new):
    #     print 'fasdfsdf', new
    #     self._suppress = True
    #     self.ymin = new[0]
    #     self.ymax = new[1]
    #     self._suppress = False

    @on_trait_change('ymin, ymax')
    def _handle_ymin_max(self, name, new):
        """Mirror ymin/ymax into ylimits and mark the y limits as explicitly set.

        Does nothing while ``_suppress`` is true. ``name``/``new`` are the
        trait-notification arguments and are not used.
        """
        if not self._suppress:
            self._has_ylimits = True
            self.ylimits = (self.ymin, self.ymax)

    def set_overlay_position(self, k, v):
        """Store position *v* for overlay key *k* in ``overlay_positions``."""
        positions = self.overlay_positions
        positions[k] = v

    def has_xlimits(self):
        """Return True if x limits were flagged as set or ``xlimits`` spans a non-empty range."""
        if self._has_xlimits:
            return True
        limits = self.xlimits
        return limits is not None and limits[0] != limits[1]

    def has_ylimits(self):
        """Return True if y limits were flagged as set or ``ylimits`` spans a non-empty range."""
        if self._has_ylimits:
            return True
        limits = self.ylimits
        return limits is not None and limits[0] != limits[1]

    def clear_ylimits(self):
        """Forget explicit y limits: clear the flag, then collapse ylimits to (0, 0)."""
        # The flag is cleared before ylimits changes, matching the original order.
        self._has_ylimits = False
        self.ylimits = (0, 0)

    def clear_xlimits(self):
        """Forget explicit x limits: clear the flag, then collapse xlimits to (0, 0)."""
        # The flag is cleared before xlimits changes, matching the original order.
        self._has_xlimits = False
        self.xlimits = (0, 0)

    def _name_changed(self):
        """Enable plotting once a real (non-empty, non-NULL_STR) name is chosen."""
        name = self.name
        if not name or name == NULL_STR:
            return
        self.plot_enabled = True

    def _get_plot_name(self):
        """Return the display name paired with ``name``, or ``name`` itself.

        ``_plot_names`` is indexed by the position of ``name`` in ``names``,
        so the two lists are presumably parallel -- confirm with callers.
        """
        current = self.name
        if not self._plot_names or current not in self.names:
            return current
        return self._plot_names[self.names.index(current)]

    def _set_has_filter(self, v):
        """Accept and ignore assignments to the read-only ``has_filter`` property."""

    def _get_has_filter(self):
        """Return ``filter_str`` when it is truthy, otherwise ``sigma_filter_n``."""
        if self.filter_str:
            return self.filter_str
        return self.sigma_filter_n