Beispiel #1
0
    class EventSignalTest(Atom):
        """Small Atom class holding an Event, a Signal and an Int member.

        Apparently a helper for the enclosing test (not visible here).
        """

        # Event member.
        e = Event()

        # Signal member.
        s = Signal()

        # Int member; presumably used by the enclosing test to count
        # notifications -- TODO confirm against the test body.
        counter = Int()
Beispiel #2
0
class GraphControllerBase(Atom):
    """Abstract controller sitting between a graph model and its view.

    Subclasses must implement the node/edge creation, destruction and
    edge-type hooks; the connection hooks have permissive defaults.
    """

    #: The graphics view driven by this controller.
    view = Typed(GraphicsView)

    #: Signal presumably emitted when the set of selected items changes.
    itemsSelected = Signal()

    def set_view(self, view):
        """Attach *view* to this controller."""
        self.view = view

    def create_node(self, typename, id=None, **kw):
        """Create a node of type *typename*. Must be overridden."""
        raise NotImplementedError()

    def destroy_node(self, id):
        """Destroy the node identified by *id*. Must be overridden."""
        raise NotImplementedError()

    def create_edge(self, typename, id=None, **kw):
        """Create an edge of type *typename*. Must be overridden."""
        raise NotImplementedError()

    def destroy_edge(self, id):
        """Destroy the edge identified by *id*. Must be overridden."""
        raise NotImplementedError()

    def edge_type_for_start_socket(self, start_node_id, start_socket_id):
        """Return the edge type to use from a start socket. Must be overridden."""
        raise NotImplementedError()

    def edge_can_connect(self, start_node_id, start_socket_id, end_node_id,
                         end_socket_id):
        """Return whether an edge may join the two sockets.

        The default implementation accepts every connection.
        """
        return True

    def edge_connected(self, id):
        """Hook for edge *id* being connected; no-op by default."""

    def edge_disconnect(self, id):
        """Hook for edge *id* being disconnected; no-op by default."""
Beispiel #3
0
class Model(Atom):
    """
    Video capture model feeding frames to worker threads.

    Read: frames are read and pushed onto a queue at a fixed period (in a
        separate thread). The period is taken from the file's frame rate,
        or falls back to the ``CAP_PERIOD`` constant (10 ms according to
        the original notes) when the source's properties are unknown.
    Edit: at another period (half the capture period), the latest frame is
        taken from the queue and processed. The code runs this in its own
        daemon thread.
    Out: processed frames are output right after processing.

    Constructor parameters
    ----------------------
    file : Unicode = ''
        Video file path
    device : Int = 0
        OpenCV device ID used only if the file doesn't exist
    """
    file = Unicode('')
    device = Int(0)
    captured_frames = Typed(deque)
    edited_frames = Typed(deque)
    y = d_(Int(0))
    capturing = d_(Bool(True))
    update = Signal()

    def start_capturing(self):
        """
        Start the capture and editing worker threads.

        Both queues are reset first. Both threads are daemons, so they do not
        keep the process alive on exit.
        """
        self.captured_frames = deque()
        self.edited_frames = deque()
        if p.isfile(self.file):
            capture = cv.VideoCapture(self.file)
            # Property id 5 is the frame rate (CAP_PROP_FPS). Some backends
            # report 0 (or NaN) for it, which previously raised
            # ZeroDivisionError; fall back to the default period instead.
            fps = capture.get(5)
            period = 1 / fps if fps > 0 else CAP_PERIOD
        else:
            capture = cv.VideoCapture(self.device)
            period = CAP_PERIOD
        thread_capturer = Thread(target=worker_capturer,
                                 args=(capture, period, self))
        # The editor polls twice as often as frames are captured.
        thread_editor = Thread(target=worker_editor, args=(period / 2, self))
        thread_capturer.daemon = True
        thread_editor.daemon = True
        thread_capturer.start()
        thread_editor.start()

    # NOTE(review): ``X or None`` evaluates to ``X`` at runtime; the intended
    # annotation is Optional[Tuple[ndarray, ndarray]].
    def frame_and_cut(self) -> Tuple[ndarray, ndarray] or None:
        """
        Pop the newest (frame, cut) pair if the deque holds more than one.

        At least one pair is always left in the deque.

        Returns
        -------
        (frame, cut) or None
        """
        if len(self.edited_frames) > 1:
            frame, cut = self.edited_frames.pop()
            return frame, cut
        return None
Beispiel #4
0
class PlotSettings(Atom):
    """ Common base for every plot-settings object.

    """
    # Title displayed for the plot.
    title = Unicode()

    # Emitted whenever a setting (here: the title) is updated.
    settings_changed = Signal()

    def _observe_title(self, change):
        # Only real modifications matter; the initial 'create' event is ignored.
        if change['type'] != 'update':
            return
        self.settings_changed.emit()

    def copy(self):
        settings_cls = type(self)
        return settings_cls(title=self.title)
class Scene3D(Declarative):
    """Root of a 3D scene; links SceneGraphNode children back to itself."""

    trigger_update = Signal()

    def child_added(self, child):
        """Set this scene as *child*'s root when it is a SceneGraphNode."""
        super(Scene3D, self).child_added(child)
        if not isinstance(child, SceneGraphNode):
            return
        child.scene_root = self

    def child_removed(self, child):
        """Clear *child*'s scene root when it is a SceneGraphNode."""
        super(Scene3D, self).child_removed(child)
        if not isinstance(child, SceneGraphNode):
            return
        child.scene_root = None

    @property
    def nodes(self):
        """List of the direct children that are SceneGraphNode instances."""
        scene_nodes = []
        for candidate in self.children:
            if isinstance(candidate, SceneGraphNode):
                scene_nodes.append(candidate)
        return scene_nodes
class OpenGLWidget(Control, MouseHandler, KeyHandler):
    """ An extremely simple widget for displaying OpenGL.

    """

    #: the renderer for the widget
    renderer = d_(Typed(Renderer))

    #: trigger a widget update
    update = d_(Signal())

    mouse_handler = d_(Typed(MouseHandler))
    key_handler = d_(Typed(KeyHandler))

    #: An opengl control expands freely in height and width by default.
    hug_width = set_default('ignore')
    hug_height = set_default('ignore')

    #: A reference to the ProxyOpenGLWidget object
    proxy = Typed(ProxyOpenGLWidget)

    #--------------------------------------------------------------------------
    # Observers
    #--------------------------------------------------------------------------
    @observe('renderer')
    def _update_proxy(self, change):
        """ Send state changes to the proxy and keep the renderer's
        ``trigger_update`` signal connected to a widget repaint.

        Instance-level ``observe`` does not resolve dotted names, so
        observing ``"renderer.trigger_update"`` on ``self`` never fired. The
        old renderer is unhooked and the new one hooked directly instead.
        Base classes also route their own observed members through
        ``_update_proxy``, so the rewiring is limited to ``renderer``.
        """
        if change['name'] == 'renderer':
            old, new = self._renderer_change_values(change)
            if old is not None:
                old.unobserve('trigger_update', self._request_repaint)
            if new is not None:
                new.observe('trigger_update', self._request_repaint)
        super(OpenGLWidget, self)._update_proxy(change)

    @observe('update')
    def _update_canvas(self, *args):
        """ An observer which propagates update events to the widget
        """
        self._request_repaint()

    @observe('mouse_handler', 'key_handler')
    def _update_handlers(self, change):
        """ An observer which connects key/mouse handlers
        """
        super(OpenGLWidget, self)._update_proxy(change)

    #--------------------------------------------------------------------------
    # Private helpers
    #--------------------------------------------------------------------------
    @staticmethod
    def _renderer_change_values(change):
        """ Return the (old, new) values described by an atom change dict.

        'create' changes carry no old value, and 'delete' changes carry the
        removed value under 'value'.
        """
        kind = change['type']
        if kind == 'update':
            return change['oldvalue'], change['value']
        if kind == 'delete':
            return change['value'], None
        return None, change['value']

    def _request_repaint(self, *args, **kwargs):
        """ Ask the proxy to repaint. This is a no-op until the proxy exists.

        Any signal payload is accepted and ignored.
        """
        proxy = self.proxy
        if proxy is not None:
            proxy.update()
Beispiel #7
0
class HistogramPlotSettings(XYPlotSettings):
    # The method to use for binning the data. See ``auto_histogram()``.
    bin_method = Enum('fd', 'rice', 'sqrt', 'sturges', 'custom')

    # The number of bins to use. Can be set automatically by a heuristic
    # (`bin_method`) or set manually (when `bin_method == 'custom'`).
    bins = Typed(OrderedDict)

    # A signal emitted when a bin number is changed.
    bins_changed = Signal()

    # Whether to show the usual histogram, a smoothed kernel density estimate,
    # and a "rug" plot on the x-axis, respectively.
    show_hist = Bool(True)
    show_kde = Bool(False)
    show_rug = Bool(False)

    def set_bins(self, group_name, bins):
        """ Set the number of bins for a group. This method must be used
        instead of directly modifying the dictionary, as Atom doesn't
        currently support container notifications for dictionaries.

        ``bins_changed`` is emitted only when the value actually changes.
        """
        if self.bins is None:
            # ``Typed`` without a factory defaults to None.
            self.bins = OrderedDict()
        if self.bins.get(group_name) != bins:
            self.bins[group_name] = bins
            self.bins_changed.emit()

    def trim_bins(self, valid_group_names):
        """ Remove all bin keys that are not in valid_group_names.

        The stale keys are collected first, because deleting from a dict
        while iterating over it raises ``RuntimeError``.
        """
        if not self.bins:
            return
        stale_keys = [key for key in self.bins
                      if key not in valid_group_names]
        for key in stale_keys:
            del self.bins[key]

    def copy(self):
        """ Return a copy of these settings with an independent bins dict.
        """
        copied = super(HistogramPlotSettings, self).copy()

        copied.bin_method = self.bin_method
        if self.bins is not None:
            copied.bins = self.bins.copy()
        copied.show_hist = self.show_hist
        copied.show_kde = self.show_kde
        copied.show_rug = self.show_rug

        return copied
Beispiel #8
0
class GenericI3pyTask(InstrumentTask):
    """Base class for all tasks calling instruments.

    """
    #: List of instruction the task should perform. This list should not be
    #: manipulated by user code.
    instructions = List()

    #: Signal emitted to notify listener that the instructions list has been
    #: modified.
    instruction_changed = Signal()

    def check(self, *args, **kwargs):
        """Check that all instructions are properly configured.

        Returns
        -------
        test : bool
            False if the base checks or any instruction check failed.

        traceback : dict
            Errors from the base checks, plus one entry per failing
            instruction keyed by ``<error path>-<instruction id>``.

        """
        test, traceback = super().check(*args, **kwargs)
        if not test:
            return test, traceback

        err_path = self.get_error_path()
        run_time = self.root.run_time
        _, d_id, _, _ = self.selected_instrument

        # This is safe since the InstrumentTask checks passed
        d_cls, _ = run_time[DRIVER_DEPENDENCY_ID][d_id]

        for instr in self.instructions:
            # Use a dedicated name so that a passing instruction cannot mask
            # the failure of an earlier one.
            instr_ok, value_or_error = instr.check(self, d_cls)
            if instr_ok:
                for entry_id, value in value_or_error.items():
                    self.write_in_database(entry_id, value)
            else:
                test = False
                traceback[err_path + '-' + instr.id] = value_or_error

        # Previously the method fell off the end and returned None here.
        return test, traceback

    def perform(self):
        """Call all instructions in order.

        """
        for i in self.instructions:
            i.execute(self, self.driver)

    def add_instruction(self, instruction, index):
        """Add an instruction at the given index.

        Parameters
        ----------
        instruction : BaseInstruction
            Instruction to insert in the list of instructions.

        index : int
            Index at which to insert the new instruction.

        """
        self.instructions.insert(index, instruction)

        # In the absence of a root task do nothing else than inserting the
        # child.
        if self.has_root:

            # Register the new entries in the database
            db_entries = self.database_entries.copy()
            db_entries.update(instruction.database_entries)
            self.database_entries = db_entries
            instruction.observe('database_entries',
                                self._react_to_instr_database_entries_change)

            # Register anew preferences to keep the right ordering for the
            # instructions
            self.register_preferences()

            change = ContainerChange(obj=self,
                                     name='instructions',
                                     added=[(index, instruction)])
            self.instruction_changed(change)

    def remove_instruction(self, index):
        """Remove an instruction from the instructions list.

        Parameters
        ----------
        index : int
            Index at which the instruction to remove is located.

        """
        instruction = self.instructions.pop(index)

        # Cleanup database. Iterate over the instruction's entries rather than
        # deleting from the dict being iterated (which raises RuntimeError).
        db_entries = self.database_entries.copy()
        for k in instruction.database_entries:
            db_entries.pop(k, None)
        self.database_entries = db_entries
        instruction.unobserve('database_entries',
                              self._react_to_instr_database_entries_change)

        # Update preferences
        self.register_preferences()

        change = ContainerChange(obj=self,
                                 name='instructions',
                                 removed=[(index, instruction)])
        self.instruction_changed(change)

    def move_instruction(self, old, new):
        """Move an instruction.

        Parameters
        ----------
        old : int
            Index at which the instruction to move is currently located.

        new : int
            Index at which to insert the instruction.

        """
        instruction = self.instructions.pop(old)
        self.instructions.insert(new, instruction)

        # In the absence of a root task do nothing else than moving the
        # child.
        if self.has_root:
            # Register anew preferences to keep the right ordering for the
            # children
            self.register_preferences()

            change = ContainerChange(obj=self,
                                     name='instructions',
                                     moved=[(old, new, instruction)])
            self.instruction_changed(change)

    def register_preferences(self):
        """Create the task entries in the preferences object.

        Each instruction is stored under ``instruction_<position>``.

        """
        super().register_preferences()

        # Register the instructions
        for i, instr in enumerate(self.instructions):
            child_id = 'instruction_{}'.format(i)
            self.preferences[child_id] = instr.preferences_from_members()

    @classmethod
    def build_from_config(cls, config, dependencies):
        """Create a new instance using the provided infos for initialisation.

        Parameters
        ----------
        config : dict(str)
            Dictionary holding the new values to give to the members in string
            format, or dictionary like for instance with prefs.

        dependencies : dict
            Dictionary holding the necessary classes needed when rebuilding.

        """
        task = cls()
        update_members_from_preferences(task, config)

        # Collect and build the instructions, stored under consecutive
        # ``instruction_<i>`` keys.
        i = 0
        pref = 'instruction_{}'
        instructions = []
        while True:
            instr_name = pref.format(i)
            if instr_name not in config:
                break
            instr_config = config[instr_name]
            # NOTE(review): pop() mutates the caller's config in place.
            instr_class_name = instr_config.pop('instruction_id')
            instr_cls = dependencies[DEP_TYPE][instr_class_name]
            instr = instr_cls.build_from_config(instr_config, dependencies)
            instructions.append(instr)
            i += 1

        task.instructions = instructions

        return task

    def traverse(self, depth=-1):
        """Yield a task and all of its components.

        Yields the task itself followed by each of its instructions.

        Parameters
        ----------
        depth : int
            Accepted for API compatibility; it is not used here since
            instructions have no children.

        """
        yield self
        for instr in self.instructions:
            yield instr

    # =========================================================================
    # --- Private API ---------------------------------------------------------
    # =========================================================================

    def _react_to_instr_database_entries_change(self, change):
        """Update the database entries whenever an instruction modify its used
        names.

        Atom change dicts store the previous value under ``'oldvalue'``. The
        previous code tested ``'old_value'`` and compared against the *new*
        entries, so stale names were never removed.

        """
        db_entries = self.database_entries.copy()
        old_entries = change.get('oldvalue')
        if old_entries:
            for k in old_entries:
                db_entries.pop(k, None)
        db_entries.update(change['value'])
        self.database_entries = db_entries
Beispiel #9
0
class EditorModel(Atom):
    """Model driving the database access editor.

    """
    #: Reference to the root task of the currently edited task hierarchy.
    root = Typed(RootTask)

    #: Signal that a node was deleted (the payload is the node model object).
    node_deleted = Signal()

    #: Dictionary storing the nodes for all tasks by path.
    nodes = Dict()

    def increase_exc_level(self, path, entry):
        """Increase the exception level of an access exception.

        Parameters
        ----------
        path : unicode
            Path of the node in which the exception to increase is.

        entry : unicode
            Entry whose access exception should be increased.

        """
        self._modify_exception_level(path, entry, 1)

    def decrease_exc_level(self, path, entry):
        """Decrease the exception level of an access exception.

        Parameters
        ----------
        path : unicode
            Path of the node in which the exception to decrease is.

        entry : unicode
            Entry whose access exception should be decreased.

        """
        self._modify_exception_level(path, entry, -1)

    # =========================================================================
    # --- Private API ---------------------------------------------------------
    # =========================================================================

    def _modify_exception_level(self, path, entry, val):
        """Modify the exception level of an access exception.

        Parameters
        ----------
        path : unicode
            Path of the node in which the exception to modify is.

        entry : unicode
            Entry whose access exception should be modified.

        val : int
            Amount by which to modify the level.

        """
        database_node = self.root.database.go_to_path(path)
        # The node's access metadata gives the relative path of the node that
        # actually owns the entry.
        real_path = path + '/' + database_node.meta['access'][entry]
        task, entry = self.nodes[real_path]._find_task_from_entry(entry)
        level = task.access_exs[entry]
        task.modify_access_exception(entry, level + val)

    def _post_setattr_root(self, old, new):
        """Ensure we are observing the right database.

        """
        if old:
            old.database.unobserve('notifier', self._react_to_entries)
            old.database.unobserve('access_notifier',
                                   self._react_to_exceptions)
            old.database.unobserve('nodes_notifier', self._react_to_nodes)

        if new:
            new.database.observe('notifier', self._react_to_entries)
            new.database.observe('access_notifier', self._react_to_exceptions)
            new.database.observe('nodes_notifier', self._react_to_nodes)

            database_nodes = new.database.list_nodes()
            nodes = {
                p: self._model_from_node(p, n)
                for p, n in database_nodes.items()
            }
            # Link every non-root node model to its parent model.
            for p, m in nodes.items():
                if '/' in p:
                    parent_path, _ = p.rsplit('/', 1)
                    m.parent = nodes[parent_path]
                    nodes[parent_path].children.append(m)

            for nmodel in nodes.values():
                nmodel.sort_nodes()

            self.nodes = nodes

    def _react_to_entries(self, news):
        """Handle modification to entries.

        ``news`` is either a list of news or a single tuple
        ``(kind, path/entry[, new_path/new_entry])``.

        """
        if isinstance(news, list):
            for n in news:
                self._react_to_entries(n)
            return

        path, entry = news[1].rsplit('/', 1)
        n = self.nodes[path]
        if news[0] == 'added':
            n.entries = n.entries[:] + [entry]

        elif news[0] == 'renamed':
            entries = n.entries[:]
            entries.remove(entry)
            entries.append(news[2].rsplit('/', 1)[1])
            n.entries = entries

        elif news[0] == 'removed':
            entries = n.entries[:]
            entries.remove(entry)
            n.entries = entries

    def _react_to_exceptions(self, news):
        """Handle modifications to the access exceptions.

        ``news[1]`` is the node path, ``news[2]`` the relative path of the
        node owning the entry (empty when it is the node itself) and
        ``news[3]`` / ``news[4]`` the entry names.

        """
        if isinstance(news, list):
            for n in news:
                self._react_to_exceptions(n)
            return

        path = news[1]
        n = self.nodes[path]
        origin_node = self.nodes[path + '/' + news[2] if news[2] else path]
        if news[0] == 'added':
            n.exceptions = n.exceptions[:] + [news[3]]

            # Extend the origin node's own list (previously copied from
            # ``n``, unlike the 'renamed' and 'removed' branches).
            origin_node.has_exceptions = (origin_node.has_exceptions[:] +
                                          [news[3]])

        elif news[0] == 'renamed':
            exceptions = n.exceptions[:]
            exceptions.remove(news[3])
            exceptions.append(news[4])
            n.exceptions = exceptions

            exs = origin_node.has_exceptions[:]
            exs.remove(news[3])
            exs.append(news[4])
            origin_node.has_exceptions = exs

        elif news[0] == 'removed':
            exceptions = n.exceptions[:]
            if news[3]:
                exceptions.remove(news[3])
                n.exceptions = exceptions

                exs = origin_node.has_exceptions[:]
                exs.remove(news[3])
                origin_node.has_exceptions = exs
            else:
                n.exceptions = []
                origin_node.has_exceptions = []

    def _react_to_nodes(self, news):
        """Handle modifications of the database nodes.

        """
        if isinstance(news, list):
            for n in news:
                self._react_to_nodes(n)
            return

        path = news[1] + '/' + news[2]
        if news[0] == 'added':
            parent = self.nodes[news[1]]
            model = self._model_from_node(path, news[3])
            model.parent = parent
            parent.children.append(model)
            parent.sort_nodes()
            self.nodes[path] = model

        elif news[0] == 'renamed':
            new_path = news[1] + '/' + news[3]
            # Match the node itself and its descendants only; a bare
            # startswith(path) would also rename siblings such as 'a/b2'
            # when renaming 'a/b'.
            descendant_prefix = path + '/'
            # Snapshot the items since self.nodes is modified in the loop.
            for k, v in list(self.nodes.items()):
                if k == path or k.startswith(descendant_prefix):
                    del self.nodes[k]
                    self.nodes[new_path + k[len(path):]] = v

        elif news[0] == 'removed':
            node = self.nodes[path]
            del self.nodes[path]
            parent = node.parent
            parent.children.remove(node)
            parent.sort_nodes()
            self.node_deleted(node)

    def _get_task(self, path):
        """Retrieve the task corresponding to a certain path.

        Raises
        ------
        ValueError
            If no task matches one of the path components.

        """
        if '/' not in path:
            return self.root

        names = path.split('/')[1:]
        task = self.root
        for n in names:
            for t in task.gather_children():
                if t.name == n:
                    task = t
                    break
            else:
                raise ValueError('No task matching the specified path')

        return task

    def _model_from_node(self, path, node):
        """Build a new model from a node informations.

        Entries are the node's data keys whose value is not a sub-node;
        exceptions are the keys of its 'access' metadata.

        """
        entries = [
            k for k, v in node.data.items() if not isinstance(v, DatabaseNode)
        ]
        excs = list(node.meta.get('access', {}))
        return NodeModel(editor=self,
                         entries=entries,
                         exceptions=excs,
                         task=self._get_task(path))
Beispiel #10
0
    # Minimal Atom subclass exposing the parametrized member as ``m``.
    class A(Atom):

        m = member

    # The member must use the NoOp delete mode.
    assert A.m.delattr_mode[0] == DelAttr.NoOp
    a = A()
    a.m = 1
    # Deleting is a silent no-op: the value survives ``del``.
    del a.m
    assert a.m == 1
    # Invoking the delattr handler directly behaves the same and returns None.
    assert A.m.do_delattr(a) is None
    assert a.m == 1


@pytest.mark.parametrize("member, mode",
                         [(Event(), DelAttr.Event),
                          (Signal(), DelAttr.Signal),
                          (ReadOnly(), DelAttr.ReadOnly),
                          (Constant(1), DelAttr.Constant)])
def test_undeletable(member, mode):
    """Check that members which cannot be deleted raise TypeError on del.

    The member's delattr mode is also checked against the expected one.

    """
    class Undeletable(Atom):

        m = member

    assert Undeletable.m.delattr_mode[0] == mode
    instance = Undeletable()
    with pytest.raises(TypeError):
        del instance.m
Beispiel #11
0
    signal_handler: signals are not settable
    delegate_handler: not tested here (see test_delegate.py)
    property_handler: not tested here (see test_property.py)
    call_object_object_value_handler: use a custom function
    call_object_object_name_value_handler: use a custom function
    object_method_value_handler: use an object method
    object_method_name_value_handler: use an object method
    member_method_object_value_handler: method defined on a Member subclass

"""
import pytest

from atom.api import Atom, Constant, Int, ReadOnly, SetAttr, Signal


@pytest.mark.parametrize("member, mode", [(Signal(), "Signal"),
                                          (Constant(1), "Constant")])
def test_unsettable(member, mode):
    """Assigning to an unsettable member raises a TypeError naming its mode."""
    class Unsettable(Atom):

        m = member

    obj = Unsettable()
    expected_mode = getattr(SetAttr, mode)
    assert obj.get_member("m").setattr_mode[0] == expected_mode
    with pytest.raises(TypeError) as excinfo:
        obj.m = 1
    message = excinfo.exconly()
    assert mode.lower() in message


@pytest.mark.parametrize("member, mode", [(Int(), "Slot"),
Beispiel #12
0
class Renderer(Declarative):
    """Base OpenGL renderer drawing the nodes of a Scene3D.

    Concrete renderers must implement :meth:`render`.
    """

    #: items
    scene = d_(Typed(Scene3D))

    #: the canvas size as reported by resizeGL
    canvas_size = d_(Typed(Size))

    #: background color
    background_color = d_(Typed(np.ndarray))

    #: trigger an update to
    trigger_update = Signal()

    #: Cyclic notification guard flags.
    _guard = d_(Int(0))

    def initialize_gl(self, widget):
        """Initialize every scene node and relay the scene's update requests."""
        for item in self.scene.nodes:
            item.initialize_gl()
        self.scene.observe("trigger_update", self.check_trigger_update)

    def enable_trigger(self, value):
        """Allow (truthy *value*) or suppress forwarding of update requests."""
        if value:
            self._guard &= ~RENDERING_FLAG
        else:
            self._guard |= RENDERING_FLAG

    def check_trigger_update(self, *args, **kwargs):
        """Re-emit ``trigger_update`` unless the rendering flag is set.

        Any signal payload is accepted and ignored, so this also works when
        the scene's signal is emitted with arguments.
        """
        if self._guard & RENDERING_FLAG:
            return
        self.trigger_update()

    def resize_gl(self, widget, size):
        self.canvas_size = size

    def paint_gl(self, widget):
        """Clear the screen and render, guarding against re-entrant paints."""
        if self._guard & RENDERING_FLAG:
            return
        self._guard |= RENDERING_FLAG

        try:
            self.clear_screen()
            self.render(widget)
            # swap buffers manually ?
        except Exception as e:
            log.exception(e)
        finally:
            # Always clear the flag; if a BaseException (e.g.
            # KeyboardInterrupt) escaped, painting would otherwise stop
            # for good.
            self._guard &= ~RENDERING_FLAG

    def clear_screen(self):
        """Clear the color and depth buffers, using background_color if set."""
        if self.background_color is not None:
            glClearColor(*self.background_color.flatten())
        glClear(GL_DEPTH_BUFFER_BIT | GL_COLOR_BUFFER_BIT)

    def render_items(self, context):
        """Render each scene node with its own copy of *context*."""
        for item in self.scene.nodes:
            item.render(context.copy())

    # overwrite in renderer implementations
    def render(self, widget):
        raise NotImplementedError
Beispiel #13
0
class BaseEngine(Atom):
    """ Base class for all engines.

    An engine is responsible for performing a measurement given a hierarchical
    ensemble of tasks.

    """

    #: Declaration defining this engine.
    declaration = ForwardTyped(lambda: Engine)

    #: Signal used to pass news about the measurement progress.
    news = Signal()

    #: Event through which the engine signals it is done with a measure.
    done = Event()

    #: Bool representing the current state of the engine.
    active = Bool()

    #: Tuple representing the status of the running measure of the engine.
    #: This must be a length-2 tuple which the plugin will map to the status
    #: and infos of the measure being processed.
    measure_status = Tuple()

    def prepare_to_run(self, name, root, monitored_entries, build_deps):
        """ Make the engine ready to perform a measure.

        This method does not start the engine.

        Parameters
        ----------
        name : str
            Name of the measure.

        root : RootTask
            The root task representing the measure to perform.

        monitored_entries : iterable
            The database entries to observe. Any change of one of these entries
            will be notified by the news event.

        build_deps : dict
            Dict holding the build dependencies of the task.

        """
        raise self._not_implemented('prepare_to_run')

    def run(self):
        """ Start the execution of the measure by the engine.

        This method must not wait for the measure to complete to return.

        """
        raise self._not_implemented('run')

    def pause(self):
        """ Ask the engine to pause the current measure.

        This method should not wait for the measure to pause to return.
        When the pause is effective the engine should add pause to the plugin
        flags.

        """
        raise self._not_implemented('pause')

    def resume(self):
        """ Ask the engine to resume the currently paused measure.

        This method should not wait for the measure to resume.
        This method should remove the 'paused' flag from the plugin flags.

        """
        raise self._not_implemented('resume')

    def stop(self):
        """ Ask the engine to stop the current measure.

        This method should not wait for the measure to stop.

        """
        raise self._not_implemented('stop')

    def exit(self):
        """ Ask the engine to stop completely.

        After a call to this method the engine may need to re-initialize a
        number of things before running the next measure. This method should
        not wait for the engine to exit.

        """
        raise self._not_implemented('exit')

    def force_stop(self):
        """ Force the engine to stop the current measure.

        This method should stop the measure no matter what is going on. It can
        block.

        """
        raise self._not_implemented('force_stop')

    def force_exit(self):
        """ Force the engine to exit.

        This method should stop the process no matter what is going on. It can
        block.

        """
        raise self._not_implemented('force_exit')

    def _not_implemented(self, method_name):
        """ Build the error raised by the abstract methods of this class.

        The previous messages were empty, which made failures hard to trace.

        """
        return NotImplementedError(
            '{} must implement the {} method.'.format(type(self).__name__,
                                                      method_name))
class QtListStrWidget(RawWidget):
    """ A Qt4 implementation of an Enaml ProxyListStrView.

    """

    __slots__ = '__weakref__'

    #: The list of str being viewed
    items = d_(List(Unicode()))

    #: Check state of each item, kept index-aligned with ``items``.
    checked_states = d_(ContainerList(Bool()))

    #: The index of the currently selected str
    selected_index = d_(Int(-1))

    #: The currently selected str
    selected_item = d_(Unicode())

    #: Whether or not the items should be checkable
    checkable = d_(Bool(True))

    #: Whether or not the items should be editable
    editable = d_(Bool(True))

    #: Optional callable taking a label and returning whether it is valid;
    #: labels it rejects are drawn in red.
    validator = d_(Callable())

    #: Use a weak width hug.
    hug_width = set_default('weak')

    #: Emitted with (old_label, new_label) when an item's text is edited.
    item_changed = Signal()
    #: Emitted with (label, checked) when an item's check state changes.
    enable_changed = Signal()

    #--------------------------------------------------------------------------
    # Initialization API
    #--------------------------------------------------------------------------
    def create_widget(self, parent):
        """ Create the QListWidget widget.

        """
        # ``zip`` below stops at the shorter sequence: give every item a
        # check state (default checked, like ``add_item``) so none is
        # silently dropped from the widget.
        missing = len(self.items) - len(self.checked_states)
        if missing > 0:
            self.checked_states.extend([True] * missing)

        # Create the list model and accompanying controls:
        widget = QListWidget(parent)
        for item, checked in zip(self.items, self.checked_states):
            self.add_item(widget, item, checked)

        # set selected_item here so that first change fires an 'update' rather than 'create' event
        self.selected_item = ''
        if self.items:
            self.selected_index = 0
            self.selected_item = self.items[0]
            widget.setCurrentRow(0)

        widget.itemSelectionChanged.connect(self.on_selection)
        widget.itemChanged.connect(self.on_edit)

        return widget

    def add_item(self, widget, item, checked=True):
        """ Append a QListWidgetItem for the label *item* to *widget*.

        """
        itemWidget = QListWidgetItem(item)
        if self.checkable:
            itemWidget.setCheckState(Qt.Checked if checked else Qt.Unchecked)
        if self.editable:
            _set_item_flag(itemWidget, Qt.ItemIsEditable, True)
        widget.addItem(itemWidget)
        self.apply_validator(itemWidget, itemWidget.text())

    #--------------------------------------------------------------------------
    # Signal Handlers
    #--------------------------------------------------------------------------
    def on_selection(self):
        """
        The signal handler for the index changed signal.
        """
        widget = self.get_widget()
        self.selected_index = widget.currentRow()
        self.selected_item = self.items[
            widget.currentRow()] if self.selected_index >= 0 else u''

    def on_edit(self, item):
        """
        The signal handler for the item changed signal.

        A text change is reported through ``item_changed``; otherwise the
        change is treated as a check-state change and reported through
        ``enable_changed``.
        """
        widget = self.get_widget()
        itemRow = widget.indexFromItem(item).row()
        oldLabel = self.items[itemRow]
        newLabel = item.text()
        if oldLabel != newLabel:
            self.item_changed(oldLabel, newLabel)
            self.selected_item = item.text()
            self.items[itemRow] = item.text()
            self.apply_validator(item, newLabel)
        else:
            self.checked_states[itemRow] = item.checkState() == Qt.Checked
            self.enable_changed(item.text(), self.checked_states[itemRow])

    #--------------------------------------------------------------------------
    # ProxyListStrView API
    #--------------------------------------------------------------------------

    def set_items(self, items, widget=None):
        """ Synchronise the QListWidget rows with *items*.

        Existing rows are updated in place, extra labels are appended and
        surplus rows removed. The *widget* argument is ignored; the current
        widget is always used.
        """
        widget = self.get_widget()
        count = widget.count()
        nitems = len(items)
        for idx, item in enumerate(items[:count]):
            itemWidget = widget.item(idx)
            #Update checked state before the text so that we can distinguish a checked state change from a label change
            itemWidget.setCheckState(
                Qt.Checked if self.checked_states[idx] else Qt.Unchecked)
            itemWidget.setText(item)
            self.apply_validator(itemWidget, item)
        if nitems > count:
            for item in items[count:]:
                self.add_item(widget, item)
        elif nitems < count:
            # ``range`` instead of the Python 2-only ``xrange``.
            for idx in reversed(range(nitems, count)):
                widget.takeItem(idx)

    #--------------------------------------------------------------------------
    # Utility methods
    #--------------------------------------------------------------------------

    def apply_validator(self, item, label):
        """ Color *item* red when the validator rejects *label*, else black.

        """
        if self.validator and not self.validator(label):
            item.setTextColor(QColor(255, 0, 0))
        else:
            item.setTextColor(QColor(0, 0, 0))

    #--------------------------------------------------------------------------
    # Observers
    #--------------------------------------------------------------------------
    @observe('items')
    def _update_proxy(self, change):
        """ Keep ``checked_states`` aligned with ``items`` and refresh the
        widget once it exists.

        """
        if self.get_widget() is None:
            return
        if change["name"] == "items" and change["type"] == "update":
            old_items = change["oldvalue"]
            new_items = change["value"]
            if len(old_items) > len(new_items):
                # We've lost an item: it sits at the first position where the
                # lists diverge. Unlike a set difference, this also works
                # when labels are duplicated.
                removed_index = next(
                    (i for i, (before, after)
                     in enumerate(zip(old_items, new_items))
                     if before != after),
                    len(new_items))
                del self.checked_states[removed_index]
            elif len(old_items) < len(new_items):
                # One default (checked) state per added item, not just one.
                self.checked_states.extend(
                    [True] * (len(new_items) - len(old_items)))

        self.set_items(self.items)
Beispiel #15
0
class QtImageAnalysis(QtControl, ProxyImageAnalysis):
    """PyQtGraph based implementation of the ImageAnalysis proxy.

    The widget shows a single image (``img_item``) with optional scatter
    overlays (epipolar points, grid, almucantar, principal plane, sun) and
    two ROIs the user can edit: a rectangular reconstruction ROI and a
    closed polygonal mask ROI. User edits are written back to the
    declaration (``ROI_state``, ``mask_ROI_state`` and ``mask``) and
    broadcast with ``ROI_signal``; clicks on the image emit ``LOS_signal``.

    """
    #: A reference to the widget created by the proxy.
    widget = Typed(pg.GraphicsLayoutWidget)

    server_id = Str()
    arrays_model = Instance(Atom)

    #
    # Different controls that can be displayed on the GUI
    #
    almucantar_scatter = Instance(pg.ScatterPlotItem)
    epipolar_scatter = Instance(pg.ScatterPlotItem)
    grid_scatter = Instance(pg.ScatterPlotItem)
    principalplane_scatter = Instance(pg.ScatterPlotItem)
    plot_area = Instance(pg.PlotItem)

    #
    # ROI - Rectangle ROI that can be set by the user.
    #
    ROI = Instance(pg.RectROI)
    ROI_signal = Signal()

    #
    # mask_ROI - Polygon ROI that can be use to mask buildings and other
    # obstacles.
    #
    mask_ROI = Instance(pg.PolyLineROI)

    #
    # Sun - Mark the expected position of the sun (Calculated from time data.)
    #
    Sun = Instance(pg.ScatterPlotItem)

    #
    # Signals to notify the main model of modifications
    # that need to be broadcast to the rest of the cameras.
    #
    LOS_signal = Signal()

    #
    # The image itself, a PyQtGraph ImageItem.
    #
    img_item = Instance(pg.ImageItem)

    #
    # For internal use.
    # Note:
    # These are used for avoiding double updates of the ROI and mask_ROI
    # states. The updates are caused when declaration.xxx is set (to update
    # view) which calls the 'set_' callback. These double update cause an
    # exception in the pyqtgraph (don't know why).
    #
    _internal_ROI_update = Bool(False)
    _internal_mask_ROI_update = Bool(False)

    # -------------------------------------------------------------------
    # Overlay initialization
    # -------------------------------------------------------------------

    def initEpiploarPoints(self, epipolar_coords):
        """Initialize the epipolar points on the plot.

        The coordinates are transposed before plotting, so they are expected
        as two rows (xs, ys). The name keeps its historical misspelling for
        backward compatibility.
        """
        self.epipolar_scatter = pg.ScatterPlotItem(
            size=5,
            pen=pg.mkPen(None),
            brush=pg.mkBrush(255, 255, 0, 120))
        self.epipolar_scatter.addPoints(pos=np.array(epipolar_coords).T)
        self.epipolar_scatter.setZValue(100)

        self.plot_area.addItem(self.epipolar_scatter)

    def initGrid(self, grid_coords):
        """Initialize the grid (scatter points) on the plot.

        ``grid_coords`` is unpacked as ``(xs, ys, mask)``; only the points
        selected by ``mask`` are drawn. The grid starts hidden.
        """
        # NOTE(review): this dumps the raw grid to ./grid.pkl on every
        # initialization and looks like a debugging leftover; confirm nothing
        # reads the file before removing it.
        with open("grid.pkl", "wb") as f:
            pickle.dump(grid_coords, f)

        self.grid_scatter = pg.ScatterPlotItem(
            size=3,
            pen=pg.mkPen(None),
            brush=pg.mkBrush(255, 0, 0, 255))

        xs, ys, mask = grid_coords
        self.grid_scatter.addPoints(pos=np.array((xs[mask], ys[mask])).T)
        self.grid_scatter.setZValue(100)

        self.plot_area.addItem(self.grid_scatter)
        self.grid_scatter.setVisible(False)

    def initAlmucantar(self, almucantar_coords):
        """Initialize the (initially hidden) Almucantar marker."""
        self.almucantar_scatter = pg.ScatterPlotItem(
            size=3,
            pen=pg.mkPen(None),
            brush=pg.mkBrush(255, 0, 0, 120))
        self.almucantar_scatter.addPoints(pos=np.array(almucantar_coords))
        self.almucantar_scatter.setZValue(99)
        self.almucantar_scatter.setVisible(False)

        self.plot_area.addItem(self.almucantar_scatter)

    def initPrincipalPlane(self, PrincipalPlane_coords):
        """Initialize the (initially hidden) Principal Plane marker."""
        self.principalplane_scatter = pg.ScatterPlotItem(
            size=3,
            pen=pg.mkPen(None),
            brush=pg.mkBrush(255, 0, 0, 120))
        self.principalplane_scatter.addPoints(
            pos=np.array(PrincipalPlane_coords))
        self.principalplane_scatter.setZValue(98)
        self.principalplane_scatter.setVisible(False)

        self.plot_area.addItem(self.principalplane_scatter)

    def initSun(self, Sun_coords):
        """Initialize the (initially hidden) drawing of the Sun."""
        self.Sun = pg.ScatterPlotItem(
            size=5,
            pen=pg.mkPen(None),
            brush=pg.mkBrush(255, 255, 0, 120))
        self.Sun.addPoints(pos=np.array(Sun_coords))
        self.Sun.setZValue(98)
        self.Sun.setVisible(False)

        self.plot_area.addItem(self.Sun)

    def initROIs(self, img_shape):
        """Initialize the (initially hidden) mask ROI and reconstruction ROI.

        The mask ROI is a closed polygon approximating an ellipse that covers
        90% of the image extent; ``img_shape[0]`` is used along x and
        ``img_shape[1]`` along y.
        """
        #
        # Mask ROI
        #
        angles = np.linspace(0, 2 * np.pi, MASK_INIT_RESOLUTION)
        xs = img_shape[0] * (1 + 0.9 * np.cos(angles)) / 2
        ys = img_shape[1] * (1 + 0.9 * np.sin(angles)) / 2
        mask_positions = np.vstack((xs, ys)).T
        self.mask_ROI = pg.PolyLineROI(mask_positions,
                                       closed=True,
                                       pen=dict(color=(255, 0, 0), width=2))
        self.mask_ROI.setVisible(False)

        self.plot_area.vb.addItem(self.mask_ROI)

        #
        # Reconstruction ROI
        #
        self.ROI = pg.RectROI([20, 20], [20, 20], pen=(0, 9))
        self.ROI.addRotateHandle([1, 0], [0.5, 0.5])
        self.ROI.setVisible(False)

        #
        # Callback when the user stops moving a ROI.
        #
        self.ROI.sigRegionChangeFinished.connect(self._ROI_updated)
        self.mask_ROI.sigRegionChangeFinished.connect(self._mask_ROI_updated)

        self.plot_area.vb.addItem(self.ROI)

    # -------------------------------------------------------------------
    # Mouse interaction
    # -------------------------------------------------------------------

    def mouseClicked(self, evt):
        """Callback of mouse click (used for updating the epipolar lines).

        Clicks inside the plot are mapped to image pixel coordinates, clipped
        to the image bounds, and emitted through ``LOS_signal`` as
        ``{'server_id': ..., 'pos': (x, y)}``.
        """
        pos = evt.scenePos()
        if not self.plot_area.sceneBoundingRect().contains(pos):
            return

        #
        # Map the click to the image.
        #
        mp = self.plot_area.vb.mapSceneToView(pos)

        # Each coordinate is clipped against its own image axis: x against
        # axis 0 and y against axis 1. This assumes pyqtgraph's default
        # col-major image axis order, the same convention initROIs uses.
        # (Previously both were clipped against axis 0, which is wrong for
        # non-square images.)
        shape = self.img_item.image.shape[:2]
        x = int(np.clip(mp.x(), 0, shape[0] - 1))
        y = int(np.clip(mp.y(), 0, shape[1] - 1))

        #
        # Update the LOS points.
        #
        self.LOS_signal.emit({'server_id': self.server_id, 'pos': (x, y)})

    # -------------------------------------------------------------------
    # Widget life cycle
    # -------------------------------------------------------------------

    def create_widget(self):
        """Create the PyQtGraph widget with an empty image and the ROIs."""
        self.widget = pg.GraphicsLayoutWidget(self.parent_widget())

        self.plot_area = self.widget.addPlot()
        self.plot_area.hideAxis('bottom')
        self.plot_area.hideAxis('left')

        #
        # Connect the click callback to the plot.
        #
        self.plot_area.scene().sigMouseClicked.connect(self.mouseClicked)

        #
        # Item for displaying image data
        #
        self.img_item = pg.ImageItem()
        self.plot_area.addItem(self.img_item)

        self.img_item.setImage(np.zeros(DEFAULT_IMG_SHAPE))

        #
        # Setup the ROIs
        #
        self.initROIs(DEFAULT_IMG_SHAPE)

        self.widget.resize(400, 400)

        return self.widget

    def init_widget(self):
        """ Initialize the widget from the current declaration state.

        """
        super(QtImageAnalysis, self).init_widget()
        d = self.declaration
        self.set_server_id(d.server_id)
        self.set_arrays_model(d.arrays_model)
        self.set_img_array(d.img_array)
        self.set_Almucantar_coords(d.Almucantar_coords)
        self.set_PrincipalPlane_coords(d.PrincipalPlane_coords)
        self.set_Sun_coords(d.Sun_coords)
        self.set_Epipolar_coords(d.Epipolar_coords)
        self.set_GRID_coords(d.GRID_coords)
        self.set_show_almucantar(d.show_almucantar)
        self.set_show_principalplane(d.show_principalplane)
        self.set_show_grid(d.show_grid)
        self.set_show_mask(d.show_mask)
        self.set_show_ROI(d.show_ROI)
        self.set_show_sun(d.show_sun)
        self.set_gamma(d.gamma)
        self.set_intensity(d.intensity)
        self.set_ROI_state(d.ROI_state)
        self.set_mask_ROI_state(d.mask_ROI_state)
        self.set_mask(d.mask)
        # Forward LOS clicks to the arrays model.
        self.observe('LOS_signal', self.arrays_model.updateLOS)

    # -------------------------------------------------------------------
    # Proxy setters
    # -------------------------------------------------------------------

    def set_server_id(self, server_id):
        """Store the id of the server (camera) shown by this widget."""
        self.server_id = server_id

    def set_arrays_model(self, arrays_model):
        """Store the model that receives LOS updates."""
        self.arrays_model = arrays_model

    def set_img_array(self, img_array):
        """Update the image array (converted to float64 for display)."""
        # np.float was an alias of the builtin float (float64) and has been
        # removed from recent NumPy releases.
        self.img_item.setImage(img_array.astype(np.float64))

    def set_Almucantar_coords(self, Almucantar_coords):
        """Update the Almucantar coords (creating the marker on first use)."""
        if self.almucantar_scatter is None:
            self.initAlmucantar(Almucantar_coords)
            return

        self.almucantar_scatter.setData(pos=np.array(Almucantar_coords))

    def set_PrincipalPlane_coords(self, PrincipalPlane_coords):
        """Update the Principal Plane coords (creating the marker on first use)."""
        if self.principalplane_scatter is None:
            self.initPrincipalPlane(PrincipalPlane_coords)
            return

        self.principalplane_scatter.setData(
            pos=np.array(PrincipalPlane_coords))

    def set_Epipolar_coords(self, Epipolar_coords):
        """Update the Epipolar coords (creating the marker on first use)."""
        if self.epipolar_scatter is None:
            self.initEpiploarPoints(Epipolar_coords)
            return

        self.epipolar_scatter.setData(pos=np.array(Epipolar_coords).T)

    def set_Sun_coords(self, Sun_coords):
        """Update the Sun coords (creating the marker on first use)."""
        if self.Sun is None:
            logging.debug("Initializing the sun position: %s.",
                          np.array(Sun_coords))
            self.initSun(Sun_coords)
            return

        logging.debug("Updating the sun position: %s.", np.array(Sun_coords))
        self.Sun.setData(pos=np.array(Sun_coords))

    def set_GRID_coords(self, GRID_coords):
        """Update the grid coords (creating the grid on first use)."""
        if self.grid_scatter is None:
            self.initGrid(GRID_coords)
            return

        xs, ys, mask = GRID_coords
        self.grid_scatter.setData(pos=np.array((xs[mask], ys[mask])).T)

    def set_show_almucantar(self, show):
        """Control the visibility of the Almucantar widget."""
        if self.almucantar_scatter is None:
            return

        self.almucantar_scatter.setVisible(show)

    def set_show_principalplane(self, show):
        """Control the visibility of the PrincipalPlane widget."""
        if self.principalplane_scatter is None:
            return

        self.principalplane_scatter.setVisible(show)

    def set_show_grid(self, show):
        """Control the visibility of the grid widget."""
        if self.grid_scatter is None:
            return

        self.grid_scatter.setVisible(show)

    def set_show_mask(self, show):
        """Control the visibility of the mask ROI widget."""
        self.mask_ROI.setVisible(show)

    def set_show_ROI(self, show):
        """Control the visibility of the ROI widget."""
        self.ROI.setVisible(show)

    def set_show_sun(self, show):
        """Control the visibility of the Sun widget."""
        # Guard like the other overlays: the Sun marker only exists once
        # coordinates have been set.
        if self.Sun is None:
            return

        self.Sun.setVisible(show)

    def set_gamma(self, apply_flag):
        """Apply (or remove) a gamma correction through the lookup table.

        When ``apply_flag`` is true each level ``i`` maps to
        ``255 * (i / 255) ** 0.4``; otherwise the identity table is used.
        """
        levels = np.arange(0, 256)
        if apply_flag:
            lut = (((levels / 255.0) ** 0.4) * 255).astype(np.uint8)
        else:
            lut = levels.astype(np.uint8)

        self.img_item.setLookupTable(lut)

    def set_intensity(self, intensity):
        """Set the display levels of the image to ``(0, intensity)``."""
        self.img_item.setLevels((0, intensity))

    def set_ROI_state(self, state):
        """Set the ROI.

        This is called for example when loading saved ROI. Empty states are
        ignored.
        """
        if not state:
            return

        if self._internal_ROI_update:
            #
            # These are used for avoiding double updates of the ROI and mask_ROI
            # states. The updates are caused when declaration.xxx is set (to update
            # view) which calls the 'set_' callback. These double update cause an
            # exception in the pyqtgraph (don't know why).
            #
            self._internal_ROI_update = False
            return

        # Do not echo the programmatic change back as a user edit; reconnect
        # even if setState raises so that later user edits are still seen.
        self.ROI.sigRegionChangeFinished.disconnect(self._ROI_updated)
        try:
            self.ROI.setState(state)
        finally:
            self.ROI.sigRegionChangeFinished.connect(self._ROI_updated)

    def set_mask(self, mask):
        """Set the manual mask.

        Intentionally a no-op: the mask is computed internally from the mask
        ROI (see ``_update_mask``).
        """
        pass

    def set_mask_ROI_state(self, state):
        """Set the mask ROI and recompute the mask.

        This is called for example when loading saved ROI. Empty states are
        ignored.
        """
        if not state:
            return

        if self._internal_mask_ROI_update:
            #
            # These are used for avoiding double updates of the ROI and mask_ROI
            # states. The updates are caused when declaration.xxx is set (to update
            # view) which calls the 'set_' callback. These double update cause an
            # exception in the pyqtgraph (don't know why).
            #
            self._internal_mask_ROI_update = False
            return

        # Do not echo the programmatic change back as a user edit; reconnect
        # even if setState raises so that later user edits are still seen.
        self.mask_ROI.sigRegionChangeFinished.disconnect(
            self._mask_ROI_updated)
        try:
            self.mask_ROI.setState(state)
        finally:
            self.mask_ROI.sigRegionChangeFinished.connect(
                self._mask_ROI_updated)

        self._update_mask()

    # -------------------------------------------------------------------
    # ROI callbacks
    # -------------------------------------------------------------------

    def _ROI_updated(self, *args):
        """Callback of ROI update.

        This is called when the user stops moving the ROI controls. It saves
        the ROI state to the declaration and emits ``ROI_signal`` with the
        image coordinates of the four ROI corners.
        """
        #
        # Propagate the state of the ROI (used for saving from the main GUI).
        #
        self._internal_ROI_update = True
        self.declaration.ROI_state = self.ROI.saveState()

        #
        # Signal change of ROI (used by the map3d)
        #
        # Calculate the bounds: map the ROI corners through the transform
        # returned by getArraySlice.
        #
        _, tr = self.ROI.getArraySlice(self.img_item.image, self.img_item)
        size = self.ROI.state['size']
        corners = ((0, 0), (size.x(), 0), (0, size.y()), (size.x(), size.y()))
        pts = np.array([tr.map(x, y) for x, y in corners])

        self.ROI_signal.emit({
            'server_id': self.server_id,
            'pts': pts,
            'shape': self.img_item.image.shape
        })

    def _mask_ROI_updated(self, *args):
        """Callback of mask ROI update.

        This is called when the user stops moving the mask ROI controls.
        """
        #
        # Propagate the state of the mask_ROI (used for saving from the main GUI).
        #
        self._internal_mask_ROI_update = True
        self.declaration.mask_ROI_state = self.mask_ROI.saveState()

        self._update_mask()

    def _update_mask(self):
        """Rasterize the mask ROI into a uint8 mask and push it to the declaration."""
        #
        # Get mask ROI region.
        #
        data = np.ones(self.img_item.image.shape[:2], np.uint8)
        sl, _ = self.mask_ROI.getArraySlice(data, self.img_item, axes=(0, 1))
        sl_mask = self.mask_ROI.getArrayRegion(data, self.img_item)

        #
        # The new version of pyqtgraph has some rounding problems.
        # Fix the slices accordingly, sizing them from the extracted region.
        #
        fixed_slices = (slice(sl[0].start, sl[0].start + sl_mask.shape[0]),
                        slice(sl[1].start, sl[1].start + sl_mask.shape[1]))

        mask = np.zeros(self.img_item.image.shape[:2], np.uint8)
        try:
            mask[fixed_slices] = sl_mask
        except ValueError as e:
            #
            # When loading old ROIs their shape might not fit the
            # resolution of the current images; keep the empty mask.
            #
            logging.debug("Mask ROI does not fit the current image: %s", e)

        #
        # Propagate the mask.
        #
        self.declaration.mask = mask
Beispiel #16
0
class AbstractSequence(Item):
    """ Base class for all sequences.

    This class defines the basic of a sequence but with only a very limited
    child support : only construction is supported, indexing is not handled
    nor is child insertion, deletion or displacement (This is because
    TemplateSequence inherits from AbstractSequence, while everything else
    inherits from BaseSequence which supports insertion/deletion/displacement).

    """

    #: Name of the sequence (help make a sequence more readable)
    name = Unicode().tag(pref=True)

    #: List of items this sequence consists of.
    items = List(Instance(Item)).tag(child=100)

    #: Signal emitted when the list of items in this sequence changes. The
    #: payload will be a ContainerChange instance.
    items_changed = Signal().tag(child_notifier='items')

    #: Dict of variables whose scope is limited to the sequence. Each key/value
    #: pair represents the name and definition of the variable.
    local_vars = Typed(OrderedDict, ()).tag(pref=(ordered_dict_to_pref,
                                                  ordered_dict_from_pref))

    #: String representing the item first element of definition : according
    #: to the selected mode its evaluated value will either be used for the
    #: start instant, or duration of the item.
    def_1 = Unicode().tag(pref=True, feval=SkipEmpty(types=Real))

    #: String representing the item second element of definition : according
    #: to the selected mode its evaluated value will either be used for the
    #: duration, or stop instant of the item.
    def_2 = Unicode().tag(pref=True, feval=SkipEmpty(types=Real))

    def clean_cached_values(self):
        """ Clear all internal caches, including those of every child item.

        This should be called before evaluating a sequence.

        """
        super(AbstractSequence, self).clean_cached_values()
        self._evaluated = []
        for i in self.items:
            i.clean_cached_values()

    def evaluate_sequence(self, root_vars, sequence_locals, missing, errors):
        """Evaluate the sequence vars and all the underlying items.

        Parameters
        ----------
        root_vars : dict
            Dictionary of global variables for the all items. This will
            typically contain the i_start/stop/duration and the root vars.
            This dict must be updated with global new values but for
            evaluation sequence_locals must be used.

        sequence_locals : dict
            Dictionary of variables whose scope is limited to this sequence
            parent. This dict must be updated with global new values and
            must be used to perform evaluation (It always contains all the
            names defined in root_vars).

        missing : set
            Set of unfound local variables.

        errors : dict
            Dict of the errors which happened when performing the evaluation.

        Returns
        -------
        flag : bool
            Boolean indicating whether or not the evaluation succeeded.

        """
        raise NotImplementedError()

    def simplify_sequence(self):
        """Simplify the items found in the sequence.

        Pulses are always kept as they are, sequences are simplified based
        on the context ability to deal with them.

        Returns
        -------
        items : list
            List of items after simplification.

        """
        raise NotImplementedError()

    @classmethod
    def build_from_config(cls, config, dependencies):
        """ Create a new instance using the provided infos for initialisation.

        Parameters
        ----------
        config : dict(str)
            Dictionary holding the new values to give to the members in string
            format, or dictionary like for instance with prefs.

        dependencies : dict
            Dictionary holding the necessary classes needed when rebuilding.

        Returns
        -------
        sequence : AbstractSequence
            Newly created and initialized sequence.

        """
        raise NotImplementedError()

    def traverse(self, depth=-1):
        """Traverse the items.

        Yields whatever the base class traversal yields, then the truthy
        children. With ``depth == 0`` only the direct children are yielded;
        otherwise each child is traversed with ``depth - 1`` (so a negative
        depth, the default, recurses through the whole tree).

        """
        for i in super(AbstractSequence, self).traverse(depth=depth):
            yield i

        if depth == 0:
            for c in self.items:
                if c:
                    yield c

        else:
            for c in self.items:
                if c:
                    for subc in c.traverse(depth - 1):
                        yield subc

    # --- Private API ---------------------------------------------------------

    #: List of already evaluated items.
    _evaluated = List()

    def _evaluate_items(self, root_vars, sequence_locals, missings, errors):
        """Evaluate all the children item of the sequence.

        Items are evaluated in repeated passes, so that an item depending on
        a variable defined by a later item can succeed on a later pass.
        Successfully evaluated enabled items are cached in ``_evaluated``
        (pulses wrapped in a one-element list, sequences stored as is).

        Parameters
        ----------
        root_vars : dict
            Dictionary of global variables for the all items. This will
            typically contain the i_start/stop/duration and the root vars.

        sequence_locals : dict
            Dictionary of variables whose scope is limited to this sequence.

        missings : set
            Set of unfound local variables. Updated in place on failure.

        errors : dict
            Dict of the errors which happened when performing the evaluation.

        Returns
        -------
        flag : bool
            Boolean indicating whether or not the evaluation succeeded.

        """
        # Inplace modification during evaluation will update self._evaluated.
        if not self._evaluated:
            self._evaluated = [None for i in self.items if i.enabled]
        evaluated = self._evaluated

        # Compilation of items in multiple passes.
        while True:
            miss = set()

            # The index counts enabled items only, matching the layout of
            # ``evaluated``. The enabled filter is re-applied on every pass.
            enabled_items = (item for item in self.items if item.enabled)
            for index, item in enumerate(enabled_items):

                # Skip evaluation if object has already been compiled.
                if evaluated[index] is not None:
                    continue

                # If we get a pulse simply evaluate the entries, to add their
                # values to the locals and keep track of the missings to know
                # when to abort compilation.
                if isinstance(item, Pulse):
                    success = item.eval_entries(root_vars, sequence_locals,
                                                miss, errors)
                    if success:
                        evaluated[index] = [item]

                # Here we got a sequence so we must try to compile it.
                else:
                    success = item.evaluate_sequence(root_vars,
                                                     sequence_locals,
                                                     miss, errors)
                    if success:
                        evaluated[index] = item

            known_locals = set(sequence_locals)
            # If none of the variables found missing during last pass is now
            # known stop compilation as we now reached a dead end. Same if an
            # error occurred.
            if errors or (miss and not (known_locals & miss)):
                # Update the missings given by caller so that it knows if this
                # failure is linked to circular references.
                missings.update(miss)
                return False

            # If no var was found missing during last pass (and as no error
            # occurred) it means the compilation succeeded.
            elif not miss:
                return True
Beispiel #17
0
class MeasureContainer(Atom):
    """Generic container for measures.

    Every mutation goes through ``add``/``remove``/``move`` so that a
    ContainerChange describing it is emitted through ``changed``.

    """
    #: List containing the measures. This must not be manipulated directly
    #: by user code.
    measures = List()

    #: Signal used to notify changes to the stored measures.
    changed = Signal()

    def add(self, measure, index=None):
        """Add a measure to the stored ones.

        Parameters
        ----------
        measure : Measure
            Measure to add.

        index : int | None
            Index at which to insert the measure. If None the measure is
            appended.

        """
        notification = ContainerChange(obj=self, name='measures')
        if index is None:
            index = len(self.measures)
            self.measures.append(measure)
        else:
            self.measures.insert(index, measure)

        notification.add_operation('added', (index, measure))
        self.changed(notification)

    def remove(self, measures):
        """Remove a measure or a list of measures.

        Parameters
        ----------
        measures : Measure|list[Measure]
            Measure(s) to remove.

        Raises
        ------
        ValueError
            If one of the measures is not stored in the container. In that
            case nothing is removed.

        """
        if not isinstance(measures, Iterable):
            measures = [measures]
        else:
            # Materialize so that one-shot iterables can be walked twice.
            measures = list(measures)

        # Validate everything up front: failing half-way would leave the list
        # mutated without any notification reaching the listeners.
        unknown = [m for m in measures if m not in self.measures]
        if unknown:
            raise ValueError('Cannot remove unknown measure(s): %s'
                             % (unknown,))

        notification = ContainerChange(obj=self, name='measures')
        for measure in measures:
            old = self.measures.index(measure)
            del self.measures[old]
            notification.add_operation('removed', (old, measure))

        self.changed(notification)

    def move(self, old, new):
        """Move a measure.

        Parameters
        ----------
        old : int
            Index at which the measure to move currently is.

        new : int
            Index at which to insert the measure (after it has been removed
            from its old position).

        """
        measure = self.measures[old]
        del self.measures[old]
        self.measures.insert(new, measure)

        notification = ContainerChange(obj=self, name='measures')
        notification.add_operation('moved', (old, new, measure))
        self.changed(notification)
Beispiel #18
0
 class EventAtom(Atom):
     """Atom exposing two independent signals, ``s1`` and ``s2``."""

     #: First signal member.
     s1 = Signal()
     #: Second signal member.
     s2 = Signal()
Beispiel #19
0
 class SignalAtom(Atom):
     """Atom exposing a single signal member ``s``."""

     #: The only member of this class.
     s = Signal()
Beispiel #20
0
class TaskDatabase(Atom):
    """ A database for inter tasks communication.

    The database has two modes:

    - an edition mode in which the number of entries and their hierarchy
      can change. In this mode the database is represented by a tree of
      DatabaseNode objects whose ``data`` dicts hold values and sub-nodes.

    - a running mode in which the entries are fixed (only their values can
      change). In this mode the database is represented as a flat list.
      In running mode the database is thread safe but the objects it
      contains may not be so (dict, list, etc)

    Paths are '/'-separated strings whose first component is 'root'
    (e.g. 'root/loop/sub').

    """
    #: Signal used to notify a value changed in the database.
    #: In edition mode the update is passed as a tuple ('added', path, value)
    #: for creation, as ('renamed', old, new, value) in case of renaming,
    #: ('removed', old) in case of deletion or as a list of such tuples.
    #: In running mode, a 2-tuple (path, value) is sent as entries cannot be
    #: renamed or removed.
    notifier = Signal()

    #: Signal emitted to notify that access exceptions have changed. The
    #: update is passed as a tuple ('added', path, relative, entry) for
    #: creation or as ('renamed', path, relative, old, new) in case of
    #: renaming of the related entry, ('removed', path, relative, old) in case
    #: of deletion (if old is None all exceptions have been removed) or as a
    #: list of such tuples.
    #: Path indicates the node where the exception is located, relative the
    #: relative path from the 'path' node to the real location of the entry.
    access_notifier = Signal()

    #: Signal emitted to notify that the nodes were modified. The update
    #: is passed as a tuple ('added', path, name, node) for creation or as
    #: ('renamed', path, old, new) in case of renaming of the related node,
    #: ('removed', path, old) in case of deletion or as a list of such tuples.
    nodes_notifier = Signal()

    #: List of root entries which should not be listed.
    excluded = List(default=['threads', 'instrs'])

    #: Flag indicating whether or not the database entered the running mode. In
    #: running mode the database is flattened into a list for faster access.
    running = Bool(False)

    def set_value(self, node_path, value_name, value):
        """Method used to set the value of the entry at the specified path

        This method can be used both in edition and running mode. In running
        mode the entry must already exist (KeyError otherwise) and the write
        and notification happen under the database lock.

        Parameters
        ----------
        node_path : unicode
            Path to the node holding the value to be set

        value_name : unicode
            Public key associated with the value to be set, internally
            converted so that we do not mix value and nodes

        value : any
            Actual value to be stored

        Returns
        -------
        new_val : bool
            Boolean indicating whether or not a new entry has been created in
            the database (always False in running mode).

        """
        full_path = node_path + '/' + value_name
        if self.running:
            index = self._entry_index_map[full_path]
            with self._lock:
                self._flat_database[index] = value
                self.notifier((full_path, value))
            return False

        node = self.go_to_path(node_path)
        new_val = value_name not in node.data
        node.data[value_name] = value
        if new_val:
            self.notifier(('added', full_path, value))

        return new_val

    def get_value(self, assumed_path, value_name):
        """Method to get a value from the database from its name and a path

        This method returns the value stored under the specified name. It
        starts looking at the specified path and if necessary goes up in the
        hierarchy, following access exceptions found on the way.

        Parameters
        ----------
        assumed_path : unicode
            Path where we start looking for the entry

        value_name : unicode
            Name of the value we are looking for

        Returns
        -------
        value : object
            Value stored under the entry value_name

        Raises
        ------
        KeyError
            If no matching entry can be found.

        """
        if self.running:
            index = self._find_index(assumed_path, value_name)
            return self._flat_database[index]

        node = self.go_to_path(assumed_path)

        # First check if the entry is in the current node.
        if value_name in node.data:
            return node.data[value_name]

        # Second check if there is a special rule (access exception) about
        # this entry, pointing to a node below this one.
        if 'access' in node.meta and value_name in node.meta['access']:
            path = assumed_path + '/' + node.meta['access'][value_name]
            return self.get_value(path, value_name)

        # Finally go one step up in the node hierarchy. A path without any
        # separator (ie 'root') has no parent: the entry does not exist.
        if '/' not in assumed_path:
            mes = "Can't find database entry : {}".format(value_name)
            raise KeyError(mes)
        return self.get_value(assumed_path.rpartition('/')[0], value_name)

    def rename_values(self, node_path, old, new, access_exs=None):
        """Rename database entries.

        This method can update the access exceptions attached to them.
        This method cannot be used in running mode. All renamings are
        validated before any modification so that a missing entry leaves the
        database untouched, and permutations of names (e.g. swapping two
        entries) are supported.

        Parameters
        ----------
        node_path : unicode
            Path to the node holding the value.

        old : iterable
            Old names of the values.

        new : sequence
            New names of the values, in the same order as ``old``.

        access_exs : dict, optional
            Dict mapping old entries names to the number of levels above
            ``node_path`` at which their access exception is located.

        Raises
        ------
        RuntimeError
            If the database is in running mode.

        KeyError
            If one of the old names is not an entry of the node.

        """
        if self.running:
            raise RuntimeError('Cannot rename an entry in running mode')

        node = self.go_to_path(node_path)
        access_exs = access_exs if access_exs else {}

        # Pair names up front (new must be indexable, as it always was).
        renames = [(old_name, new[i]) for i, old_name in enumerate(old)]

        # Validate before mutating anything.
        for old_name, _ in renames:
            if old_name not in node.data:
                err_str = 'No entry {} in node {}'.format(old_name, node_path)
                raise KeyError(err_str)

        # Pop every old entry before inserting the new ones so that a value
        # is never overwritten by another renamed value.
        values = [node.data.pop(old_name) for old_name, _ in renames]
        notif = []
        for (old_name, new_name), val in zip(renames, values):
            node.data[new_name] = val
            notif.append(('renamed', node_path + '/' + old_name,
                          node_path + '/' + new_name, val))

        # Same two-phase approach for the access exceptions.
        moved = []
        for old_name, new_name in renames:
            if old_name not in access_exs:
                continue
            # Walk up to the node holding the access exception.
            n, p = node, node_path
            for _ in range(access_exs[old_name]):
                n = n.parent if n.parent else n
                p = p.rsplit('/', 1)[0]
            rel_path = n.meta['access'].pop(old_name)
            moved.append((n, p, rel_path, old_name, new_name))

        acc_notif = []
        for n, p, rel_path, old_name, new_name in moved:
            n.meta['access'][new_name] = rel_path
            acc_notif.append(('renamed', p, rel_path, old_name, new_name))

        # Avoid sending spurious notifications
        if notif:
            self.notifier(notif)
        if acc_notif:
            self.access_notifier(acc_notif)

    def delete_value(self, node_path, value_name):
        """Remove an entry from the specified node

        This method remove the specified entry from the specified node. It does
        not handle removing the access exceptions attached to it. This
        method cannot be used in running mode.

        Parameters
        ----------
        node_path : unicode
            Path to the node holding the entry.

        value_name : unicode
            Name of the entry to remove.

        Raises
        ------
        RuntimeError
            If the database is in running mode.

        KeyError
            If the node holds no such entry.

        """
        if self.running:
            raise RuntimeError('Cannot delete an entry in running mode')

        node = self.go_to_path(node_path)
        if value_name not in node.data:
            err_str = 'No entry {} in node {}'.format(value_name, node_path)
            raise KeyError(err_str)

        del node.data[value_name]
        self.notifier(('removed', node_path + '/' + value_name))

    def get_values_by_index(self, indexes, prefix=None):
        """Access to a list of values using the flat database.

        Parameters
        ----------
        indexes : list(int)
            List of index for which values should be returned.

        prefix : unicode, optional
            If provided return the values in dict with key of the form :
            prefix + index.

        Returns
        -------
        values : list or dict
            List of requested values in the same order as indexes or dict if
            prefix was not None.

        """
        if prefix is None:
            return [self._flat_database[i] for i in indexes]
        else:
            return {prefix + str(i): self._flat_database[i] for i in indexes}

    def get_entries_indexes(self, assumed_path, entries):
        """ Access to the index in the flattened database for some entries.

        Parameters
        ----------
        assumed_path : unicode
            Path to the node in which the values are assumed to be stored.

        entries : iterable(unicode)
            Names of the entries for which the indexes should be returned.

        Returns
        -------
        indexes : dict
            Dict mapping the entries names to their index in the flattened
            database.

        """
        return {name: self._find_index(assumed_path, name) for name in entries}

    def list_accessible_entries(self, node_path):
        """Method used to get a list of all entries accessible from a node.

        DO NOT USE THIS METHOD IN RUNNING MODE (ie never in the check method
        of a task, use a try except clause instead and get_value or
        get_entries_indexes).

        Parameters
        ----------
        node_path : unicode
            Path to the node from which accessible entries should be listed.

        Returns
        -------
        entries_list : list(unicode)
            Sorted list of entries accessible from the specified node, each
            listed once even if it exists at several levels.

        """
        entries = set()
        while True:
            node = self.go_to_path(node_path)
            # Values (not sub-nodes) stored in this node.
            entries.update(key for key, val in node.data.items()
                           if not isinstance(val, DatabaseNode))

            # Entries made accessible by access exceptions.
            entries.update(node.meta.get('access', ()))

            # Stop at the top of the hierarchy. Checking for a separator also
            # guarantees termination for paths not starting with 'root'.
            if node_path == 'root' or '/' not in node_path:
                break
            node_path = node_path.rpartition('/')[0]

        entries.difference_update(self.excluded)
        return sorted(entries)

    def list_all_entries(self, path='root', values=False):
        """List all entries in the database.

        Parameters
        ----------
        path : unicode, optional
            Starting node. This parameter is for internal use only.

        values : bool, optional
            Whether or not to return the values associated with the entries.

        Returns
        -------
        paths : list(unicode) or dict if values
            List of all accessible entries with their full path (sorted), or
            dict mapping those paths to their values.

        """
        entries = {} if values else []
        node = self.go_to_path(path)
        for entry, value in node.data.items():
            entry_path = path + '/' + entry
            if isinstance(value, DatabaseNode):
                aux = self.list_all_entries(path=entry_path, values=values)
                if values:
                    entries.update(aux)
                else:
                    entries.extend(aux)
            elif values:
                entries[entry_path] = value
            else:
                entries.append(entry_path)

        if path == 'root':
            for entry in self.excluded:
                aux = path + '/' + entry
                if aux in entries:
                    if values:
                        del entries[aux]
                    else:
                        entries.remove(aux)

        return entries if values else sorted(entries)

    def add_access_exception(self, node_path, entry_node, entry):
        """Add an access exception in a node for an entry located in a node
        below.

        Parameters
        ----------
        node_path : unicode
            Path to the node which should hold the exception.

        entry_node : unicode
            Absolute path to the node holding the entry.

        entry : unicode
            Name of the entry for which to create an exception.

        """
        node = self.go_to_path(node_path)
        # Exceptions store the path of entry_node relative to node_path.
        rel_path = entry_node[len(node_path) + 1:]
        if 'access' in node.meta:
            access_exceptions = node.meta['access']
            access_exceptions[entry] = rel_path
        else:
            node.meta['access'] = {entry: rel_path}
        self.access_notifier(('added', node_path, rel_path, entry))

    def remove_access_exception(self, node_path, entry=None):
        """Remove an access exception from a node for a given entry.

        Parameters
        ----------
        node_path : unicode
            Path to the node holding the exception.

        entry : unicode, optional
            Name of the entry for which to remove the exception, if not
            provided all access exceptions will be removed.

        """
        node = self.go_to_path(node_path)
        if entry:
            access_exceptions = node.meta['access']
            relative_path = access_exceptions[entry]
            del access_exceptions[entry]
        else:
            relative_path = ''
            del node.meta['access']
        self.access_notifier(('removed', node_path, relative_path, entry))

    def create_node(self, parent_path, node_name):
        """Method used to create a new node in the database

        This method creates a new node in the database at the specified path.
        This method is not thread safe as the hierarchy of the tasks'
        database is not supposed to change during a measurement but only during
        the configuration phase

        Parameters
        ----------
        parent_path : unicode
            Path to the node parent of the new one

        node_name : unicode
            Name of the new node to create

        """
        if self.running:
            raise RuntimeError('Cannot create a node in running mode')

        parent_node = self.go_to_path(parent_path)
        node = DatabaseNode(parent=parent_node)
        parent_node.data[node_name] = node
        self.nodes_notifier(('added', parent_path, node_name, node))

    def rename_node(self, parent_path, old_name, new_name):
        """Method used to rename a node in the database

        Access exceptions held by the ancestors of the node and pointing
        through it are updated to use the new name.

        Parameters
        ----------
        parent_path : unicode
            Path to the parent of the node being renamed

        old_name : unicode
            Old name of the node.

        new_name : unicode
            New name of node

        """
        if self.running:
            raise RuntimeError('Cannot rename a node in running mode')

        parent_node = self.go_to_path(parent_path)
        # Pop before assigning so that renaming a node to its own name keeps
        # it in place.
        parent_node.data[new_name] = parent_node.data.pop(old_name)

        # Access exceptions store paths relative to the node holding them.
        # For each ancestor compute the relative path of the renamed node and
        # rewrite only the exceptions going through it, comparing whole path
        # components (never raw substrings).
        node, node_path = parent_node, parent_path
        while node:
            access = node.meta.get('access')
            if access:
                prefix = parent_path[len(node_path) + 1:]
                old_rel = prefix + '/' + old_name if prefix else old_name
                new_rel = prefix + '/' + new_name if prefix else new_name
                for entry, rel_path in list(access.items()):
                    if rel_path == old_rel:
                        access[entry] = new_rel
                    elif rel_path.startswith(old_rel + '/'):
                        access[entry] = new_rel + rel_path[len(old_rel):]

            node = node.parent
            node_path = node_path.rpartition('/')[0]

        self.nodes_notifier(('renamed', parent_path, old_name, new_name))

    def delete_node(self, parent_path, node_name):
        """Method used to delete an existing node from the database

        Parameters
        ----------
        parent_path : unicode
            Path to the parent of the node to delete

        node_name : unicode
            Name of the node to delete

        Raises
        ------
        RuntimeError
            If the database is in running mode.

        KeyError
            If the parent holds no such node.

        """
        if self.running:
            raise RuntimeError('Cannot delete a node in running mode')

        parent_node = self.go_to_path(parent_path)
        if node_name not in parent_node.data:
            err_str = 'No node {} at the path {}'.format(
                node_name, parent_path)
            raise KeyError(err_str)

        del parent_node.data[node_name]
        self.nodes_notifier(('removed', parent_path, node_name))

    def copy_node_values(self, node='root'):
        """Copy the values (ie not subnodes) found in a node.

        Parameters
        ----------
        node : unicode, optional
            Path to the node to copy.

        Returns
        -------
        copy : dict
            Copy of the node values.

        """
        node = self.go_to_path(node)
        return {
            k: v
            for k, v in node.data.items() if not isinstance(v, DatabaseNode)
        }

    def prepare_to_run(self):
        """Enter a thread safe, flat database state.

        This is used when tasks are executed. Afterwards the tree is
        discarded (``_database`` is set to None), so the edition mode methods
        relying on go_to_path can no longer be used.

        """
        self._lock = Lock()
        self.running = True

        # Flattening the database by walking all the nodes (breadth first,
        # the list grows while being iterated).
        index = 0
        nodes = [('root', self._database)]
        mapping = {}
        datas = []
        for (node_path, node) in nodes:
            for key, val in node.data.items():
                path = node_path + '/' + key
                if isinstance(val, DatabaseNode):
                    nodes.append((path, val))
                else:
                    mapping[path] = index
                    index += 1
                    datas.append(val)

        # Walking a second time to add the exception to the _entry_index_map,
        # in reverse order in case an entry has multiple exceptions.
        for (node_path, node) in nodes[::-1]:
            access = node.meta.get('access', {})
            for entry in access:
                short_path = node_path + '/' + entry
                full_path = node_path + '/' + access[entry] + '/' + entry
                mapping[short_path] = mapping[full_path]

        self._flat_database = datas
        self._entry_index_map = mapping

        self._database = None

    def list_nodes(self):
        """List all the nodes present in the database.

        Returns
        -------
        nodes : dict
            Dictionary storing the nodes by path

        """
        nodes = [('root', self._database)]
        for (node_path, node) in nodes:
            for key, val in node.data.items():
                if isinstance(val, DatabaseNode):
                    path = node_path + '/' + key
                    nodes.append((path, val))

        return dict(nodes)

    def go_to_path(self, path):
        """Method used to reach a node specified by a path.

        The first component of the path is assumed to be 'root' and is not
        checked.

        Raises
        ------
        KeyError
            If one of the components of the path does not exist.

        """
        node = self._database
        if path == 'root':
            return node

        # Decompose the path in database keys and drop the first one (ie
        # 'root' as we are not trying to access it).
        keys = path.split('/')[1:]

        for ind, key in enumerate(keys):
            if key not in node.data:
                if ind == 0:
                    err_str = 'Path {} is invalid, no node {} in root'.format(
                        path, key)
                else:
                    err_str = ('Path {} is invalid, no node {} in node '
                               '{}'.format(path, key, keys[ind - 1]))
                raise KeyError(err_str)
            node = node.data[key]

        return node

    # =========================================================================
    # --- Private API ---------------------------------------------------------
    # =========================================================================

    #: Main container for the database.
    _database = Typed(DatabaseNode, ())

    #: Flat version of the database only used in running mode for
    #: performance reasons.
    _flat_database = List()

    #: Dict mapping full paths to flat database indexes.
    _entry_index_map = Dict()

    #: Lock to make the database thread safe in running mode.
    _lock = Value()

    def _find_index(self, assumed_path, entry):
        """Find the index associated with a path.

        Only to be used in running mode. The lookup starts at assumed_path
        and goes up one level at a time until the top of the path.

        Raises
        ------
        KeyError
            If no matching entry exists.

        """
        path = assumed_path
        while True:
            full_path = path + '/' + entry
            if full_path in self._entry_index_map:
                return self._entry_index_map[full_path]
            # Stop at 'root' or at any path without a parent so that paths
            # not starting with 'root' cannot loop forever.
            if path == 'root' or '/' not in path:
                break
            path = path.rpartition('/')[0]

        raise KeyError("Can't find entry matching {}, {}".format(
            assumed_path, entry))
Beispiel #21
0
class NodeModel(Atom):
    """Editor-side model of the database node attached to a ComplexTask.

    """
    #: Task whose database node this model describes.
    task = Typed(ComplexTask)

    #: Reference to editor model.
    editor = ForwardTyped(lambda: EditorModel)

    #: Database entries available on the node associated with the task.
    entries = List()

    #: Database exceptions present on the node.
    exceptions = List()

    #: Database entries for which an access exception exists
    has_exceptions = List()

    #: Model of the node which is the parent of this one.
    parent = ForwardTyped(lambda: NodeModel)

    #: Children nodes
    children = List()

    #: Notifier for changes to the children. Only present to satisfy the
    #: TaskEditor used in the view.
    children_changed = Signal()

    def __init__(self, **kwargs):

        super(NodeModel, self).__init__(**kwargs)
        # Listen to every member of the task tagged as a child notifier.
        notifier_names = tagged_members(self.task, 'child_notifier')
        for member_name in notifier_names:
            self.task.observe(member_name, self._react_to_task_children_event)

    def sort_nodes(self):
        """Order the children nodes like the complex tasks of the task.

        """
        complex_tasks = [
            candidate for candidate in self.task.gather_children()
            if isinstance(candidate, ComplexTask)
        ]

        def task_position(node):
            return complex_tasks.index(node.task)

        self.children = sorted(self.children, key=task_position)

    def add_exception(self, entry):
        """Create an access exception (one level up) for a full entry name.

        Nothing is done if the owning task already has an exception for it.

        """
        owner, short_name = self._find_task_from_entry(entry)

        if short_name in owner.access_exs:
            return
        owner.add_access_exception(short_name, 1)

    # =========================================================================
    # --- Private API ---------------------------------------------------------
    # =========================================================================

    def _react_to_task_children_event(self, change):
        """Re-sort the nodes when children tasks were moved.

        Only move events are transparent to the database. Collapsed changes
        are processed one by one.

        """
        if not isinstance(change, ContainerChange):
            return

        for sub_change in change.collapsed or ():
            self._react_to_task_children_event(sub_change)

        if change.moved:
            self.sort_nodes()

    def _find_task_from_entry(self, full_entry):
        """Split a full entry name into its owning task and short name.

        When several children names prefix the entry, the first one declaring
        the short name in its database entries wins; if none does, the last
        candidate is used.

        """
        candidates = [
            child for child in self.task.gather_children()
            if full_entry.startswith(child.name)
        ]

        if len(candidates) <= 1:
            owner = candidates[0]
            return owner, full_entry[len(owner.name) + 1:]

        for candidate in candidates:
            short_name = full_entry[len(candidate.name) + 1:]
            if short_name in candidate.database_entries:
                return candidate, short_name

        # No candidate declares the entry: fall back on the last one.
        owner = candidates[-1]
        return owner, full_entry[len(owner.name) + 1:]
Beispiel #22
0
    class SignalAtom(type(sd_observed_atom)):
        """Subclass of the observed atom's type adding a counted signal."""

        #: Signal whose emissions are counted in ``count``.
        val = Signal()

        def _observe_val(self, change):
            # Static observer of `val` (atom's _observe_<name> convention).
            self.count = self.count + 1
Beispiel #23
0
class ComplexTask(BaseTask):
    """Task composed of several subtasks.

    """
    #: List of all the children of the task. The list should not be manipulated
    #: directly by user code.
    #: The tag 'child' is used to mark that a member can contain child tasks
    #: and is used to gather children for operations which must occur on all
    #: of them (the tag value orders the members in gather_children).
    children = List().tag(child=100)

    #: Signal emitted when the list of children change, the payload will be a
    #: ContainerChange instance.
    #: The tag 'child_notifier' is used to mark that a member emits
    #: notifications about modifications of another 'child' member. This
    #: allows editors to correctly track all of those.
    children_changed = Signal().tag(child_notifier='children')

    #: Flag indicating whether or not the task has a root task.
    has_root = Bool(False)

    def perform(self):
        """Execute the children tasks one after the other, in list order.

        """
        for task in self.children:
            task.perform_()

    def check(self, *args, **kwargs):
        """Run the checks of this task and of all its gathered children.

        An exception raised by a child check is caught, marks the whole check
        as failed and its formatted traceback is stored under the child's
        full path.

        Returns
        -------
        test : bool
            Combined result of all the checks.

        traceback : dict
            Messages collected from the checks.

        """
        test, traceback = super(ComplexTask, self).check(*args, **kwargs)
        for task in self.gather_children():
            try:
                result = task.check(*args, **kwargs)
                test = test and result[0]
                traceback.update(result[1])
            except Exception:
                test = False
                key = task.path + '/' + task.name
                template = 'An exception occured while running check :\n%s'
                traceback[key] = template % format_exc()

        return test, traceback

    def prepare(self):
        """Prepare this task, then each of its gathered children.

        """
        super(ComplexTask, self).prepare()
        for task in self.gather_children():
            task.prepare()

    def add_child_task(self, index, child):
        """Insert a child task at the given index.

        When this task has a root, the child is also configured (depth,
        database, path, parent, root), registered in the database, the
        preferences are rebuilt and ``children_changed`` is emitted.

        Parameters
        ----------
        index : int
            Index at which to insert the new child task.

        child : BaseTask
            Task to insert in the list of children task.

        """
        self.children.insert(index, child)

        # Without a root task, inserting the child is all there is to do.
        if not self.has_root:
            return

        child.depth = self.depth + 1
        child.database = self.database
        child.path = self._child_path()

        # Provide parent and root so that the child can perform any child
        # registration it needs to.
        child.parent = self
        child.root = self.root

        child.register_in_database()

        # Rebuild the preferences so that they follow the children order.
        self.register_preferences()

        self.children_changed(ContainerChange(obj=self,
                                              name='children',
                                              added=[(index, child)]))

    def move_child_task(self, old, new):
        """Move a child task.

        When this task has a root, the preferences are rebuilt and
        ``children_changed`` is emitted with a 'moved' operation.

        Parameters
        ----------
        old : int
            Index at which the child to move is currently located.

        new : int
            Index at which to insert the child task (once removed from its
            old position).

        """
        child = self.children.pop(old)
        self.children.insert(new, child)

        # In the absence of a root task do nothing else than moving the
        # child.
        if not self.has_root:
            return

        # Register anew preferences to keep the right ordering for the
        # children
        self.register_preferences()

        change = ContainerChange(obj=self,
                                 name='children',
                                 moved=[(old, new, child)])
        self.children_changed(change)

    def remove_child_task(self, index):
        """Remove a child task from the children list.

        When this task has a root, the child is also unregistered from the
        database, detached from its root and parent, the preferences are
        rebuilt and ``children_changed`` is emitted.

        Parameters
        ----------
        index : int
            Index at which the child to remove is located.

        """
        child = self.children.pop(index)

        # Mirror add_child_task: children are only registered in the database
        # and the preferences when this task has a root, so without one there
        # is nothing else to undo.
        if not self.has_root:
            return

        # Cleanup database, update preferences
        child.unregister_from_database()
        child.root = None
        child.parent = None
        self.register_preferences()

        change = ContainerChange(obj=self,
                                 name='children',
                                 removed=[(index, child)])
        self.children_changed(change)

    def gather_children(self):
        """Build a flat list of all children task.

        Members tagged 'child' are visited in increasing order of their tag
        value; empty members are skipped and iterable members are expanded.

        Returns
        -------
        children : list
            List of all the task children.

        """
        tagged = tagged_members(self, 'child')
        ordered_names = sorted(
            tagged, key=lambda member: tagged[member].metadata['child'])

        gathered = []
        for member_value in (getattr(self, n) for n in ordered_names):
            if not member_value:
                continue
            if isinstance(member_value, Iterable):
                gathered.extend(member_value)
            else:
                gathered.append(member_value)

        return gathered

    def traverse(self, depth=-1):
        """Yield this task then its children.

        With ``depth == 0`` the gathered children are yielded as-is;
        otherwise each child is traversed with ``depth - 1`` (so a negative
        depth never stops the recursion). Falsy children are skipped.

        """
        yield self

        for task in self.gather_children():
            if not task:
                continue
            if depth == 0:
                yield task
            else:
                for sub_task in task.traverse(depth - 1):
                    yield sub_task

    def register_in_database(self):
        """Register the task entries, create its node, then register children.

        Both the task entries and all the tasks tagged as child are handled.

        """
        super(ComplexTask, self).register_in_database()
        self.database.create_node(self.path, self.name)

        # The 'children' member guarantees there is something to gather.
        for task in self.gather_children():
            task.register_in_database()

    def unregister_from_database(self):
        """Unregister the task and its children, then delete its node.

        Both the task entries and all the tasks tagged as child are handled;
        the node is deleted last, once the children are gone.

        """
        super(ComplexTask, self).unregister_from_database()

        for task in self.gather_children():
            task.unregister_from_database()

        self.database.delete_node(self.path, self.name)

    def register_preferences(self):
        """Rebuild the task preferences, including those of its children.

        The preferences are cleared, members tagged 'pref' are stored, then
        every child gets its own sub-section (named ``<member>_<i>`` for
        iterable members, ``<member>`` otherwise) and registers into it.

        """

        self.preferences.clear()
        for name, member in tagged_members(self, 'pref').items():
            self.preferences[name] = member_to_pref(self, member,
                                                    getattr(self, name))

        for name in tagged_members(self, 'child'):
            value = getattr(self, name)
            if not value:
                continue

            if not isinstance(value, Iterable):
                self.preferences[name] = {}
                # Read the section back: the container may convert the dict.
                value.preferences = self.preferences[name]
                value.register_preferences()
                continue

            for i, task in enumerate(value):
                section = name + '_{}'.format(i)
                self.preferences[section] = {}
                task.preferences = self.preferences[section]
                task.register_preferences()

    def update_preferences_from_members(self):
        """Push the current member values into the preference mapping.

        Every member tagged ``pref`` is converted with ``member_to_pref`` and
        stored under its name. The same refresh is then delegated to each
        child returned by ``gather_children``.

        """
        pref_members = tagged_members(self, 'pref')
        for attr, member in pref_members.items():
            self.preferences[attr] = member_to_pref(self, member,
                                                    getattr(self, attr))

        for task in self.gather_children():
            task.update_preferences_from_members()

    @classmethod
    def build_from_config(cls, config, dependencies):
        """Create a new instance using the provided infos for initialisation.

        Parameters
        ----------
        config : dict(str)
            Dictionary holding the new values to give to the members in string
            format, or dictionary-like objects for instance with prefs.
            Child tasks are looked up as follows:

            - single child: under the member name;
            - ``List`` member: under ``<member name>_<index>`` for
              consecutive indexes starting at 0.

        dependencies : dict
            Dictionary holding the necessary classes needed when rebuilding.
            This is assembled by the TaskManager.

        Returns
        -------
        task : BaseTask
            Newly created and initialized task.

        Notes
        -----
        This method is fairly powerful and can handle a lot of cases so
        don't override it without checking that it works.

        The ``'task_id'`` entry of every child sub-config is popped, so the
        child sections of ``config`` are modified in place.

        """
        task = cls()
        update_members_from_preferences(task, config)
        for name, member in tagged_members(task, 'child').items():

            if isinstance(member, List):
                # Children are stored as name_0, name_1, ...; stop at the
                # first missing index.
                pref = name + '_{}'
                validated = []
                i = 0
                while pref.format(i) in config:
                    validated.append(cls._build_child_from_config(
                        config[pref.format(i)], dependencies))
                    i += 1

            else:
                if name not in config:
                    continue
                validated = cls._build_child_from_config(config[name],
                                                         dependencies)

            setattr(task, name, validated)

        return task

    @staticmethod
    def _build_child_from_config(child_config, dependencies):
        """Build a single child task from its sub-config.

        The ``'task_id'`` entry is popped from ``child_config`` and used to
        look up the child class in ``dependencies[DEP_TYPE]``. The remaining
        config is handed to that class's ``build_from_config``.

        """
        child_class_name = child_config.pop('task_id')
        child_class = dependencies[DEP_TYPE][child_class_name]
        return child_class.build_from_config(child_config, dependencies)

    # =========================================================================
    # --- Private API ---------------------------------------------------------
    # =========================================================================

    #: Last removed child, the list of database access exceptions attached to
    #: it, and whether its _access_exs needs to be observed.
    _last_removed = Tuple(default=(None, None, False))

    #: Last access exceptions deactivated from a child.
    _last_exs = Coerced(set)

    #: List of access_exs, linked to access exs in a child, that were disabled
    #: because that child disabled some of its access_exs.
    _disabled_exs = List()

    def _child_path(self):
        """Convenience function returning the path to set for child task.

        """
        return self.path + '/' + self.name

    def _update_children_path(self):
        """Recompute the path of every descendant task.

        Each direct child receives the path returned by ``_child_path``. The
        update then recurses into children that are themselves ComplexTask
        instances, so deeper levels are refreshed as well.

        """
        for task in self.gather_children():
            task.path = self._child_path()
            if not isinstance(task, ComplexTask):
                continue
            task._update_children_path()

    def _post_setattr_name(self, old, new):
        """React to the task being renamed at runtime.

        Nothing happens when ``old`` is falsy (presumably the initial naming)
        or when no database is attached. Otherwise:

        - the parent class hook runs;
        - the database node of this task is renamed from ``old`` to ``new``
          via ``rename_node``;
        - the now obsolete paths of all children are recomputed.

        """
        if not old or not self.database:
            return

        super(ComplexTask, self)._post_setattr_name(old, new)
        self.database.rename_node(self.path, old, new)

        # Children paths embed this task's name, so they are now stale.
        self._update_children_path()

    def _post_setattr_root(self, old, new):
        """Make sure that all children get all the info they need to behave
        correctly when the task get its root parent (ie the task is now
        in a 'correct' environnement).

        """
        if new is None:
            return

        self.has_root = True
        for child in self.gather_children():
            child.depth = self.depth + 1
            child.database = self.database
            child.path = self._child_path()

            # Give him its root so that it can proceed to any child
            # registration it needs to.
            child.parent = self
            child.root = self.root
Beispiel #24
0
    event_handler: not tested here (see test_observe.py)
    signal_handler
    delegate_handler: not tested here (see test_delegate.py)
    property_handler: not tested here (see test_property.py)
    call_object_object_value_handler: not used as far as I can tell
    call_object_object_name_value_handler: not used as far as I can tell
    object_method_value_handler: not used as far as I can tell
    object_method_name_value_handler: not used as far as I can tell
    member_method_object_value_handler: method defined on a Member subclass

"""
import pytest
from atom.api import (Atom, Int, Constant, Signal, ReadOnly, SetAttr)


@pytest.mark.parametrize("member", [Signal(), Constant(1)])
def test_unsettable(member):
    """Test that unsettable members do raise the proper error.

    Each parameter must be the member itself, not a 1-tuple wrapping it. A
    tuple would make ``m`` a plain class attribute instead of an atom member.
    Assigning to it could then fail for an unrelated reason, so the test
    would not exercise the member's set behavior at all.

    """
    class Unsettable(Atom):

        m = member

    u = Unsettable()
    # Setting a Signal or a Constant member is refused with a TypeError,
    # consistent with the TypeError expected on deletion in test_undeletable.
    with pytest.raises(TypeError):
        u.m = None


def test_read_only_behavior():
    """Test the behavior of read only member.
Beispiel #25
0
class BaseEngine(Atom):
    """Abstract base class that every engine implementation derives from.

    Subclasses are expected to override ``perform``, ``pause``, ``resume``,
    ``stop`` and ``shutdown``. The versions defined here only raise
    NotImplementedError.

    """
    #: Engine declaration this engine is associated with.
    declaration = ForwardTyped(lambda: Engine)

    #: Current state of the engine.
    status = Enum('Stopped', 'Waiting', 'Running', 'Pausing', 'Paused',
                  'Resuming', 'Stopping', 'Shutting down')

    #: Signal emitted to report on the progress of the measurement.
    progress = Signal()

    def perform(self, exec_infos):
        """Execute the given task, catching any error it raises.

        Parameters
        ----------
        exec_infos : ExecutionInfos
            Object describing the work the engine is expected to do.

        Returns
        -------
        exec_infos : ExecutionInfos
            The same object, with its values updated. Returning it is a
            convenience for the caller.

        """
        raise NotImplementedError

    def pause(self):
        """Request that the engine pause the execution.

        The call returns without waiting for the pause to happen. The engine
        reports that the pause is effective by updating its ``status``.

        """
        raise NotImplementedError

    def resume(self):
        """Request that the engine resume a paused execution.

        The call returns without waiting for the execution to restart. The
        engine reports that the pause is over by updating its ``status``.

        """
        raise NotImplementedError

    def stop(self, force=False):
        """Request that the engine stop the execution.

        The call returns without waiting for the execution to end, unless a
        forced stop is requested.

        Parameters
        ----------
        force : bool, optional
            Force the engine to stop performing the task, using any means
            necessary. Only in this case should the call block.

        """
        raise NotImplementedError

    def shutdown(self, force=False):
        """Request that the engine stop completely.

        Afterwards the engine may have to re-initialize a number of things
        before it can run the next task. The call returns without waiting for
        the shutdown to complete, unless a forced stop is requested.

        Parameters
        ----------
        force : bool, optional
            Force the engine to stop performing the task, using any means
            necessary. Only in this case should the call block.

        """
        raise NotImplementedError
Beispiel #26
0
    class A(Atom):

        m = member

    assert A.m.delattr_mode[0] == DelAttr.NoOp
    a = A()
    a.m = 1
    del a.m
    assert a.m == 1
    assert A.m.do_delattr(a) is None
    assert a.m == 1


@pytest.mark.parametrize("member, mode", [(Event(), DelAttr.Event),
                                          (Signal(), DelAttr.Signal),
                                          (ReadOnly(), DelAttr.ReadOnly),
                                          (Constant(1), DelAttr.Constant)])
def test_undeletable(member, mode):
    """Test that unsettable members do raise the proper error.

    """
    class Undeletable(Atom):

        m = member

    assert Undeletable.m.delattr_mode[0] == mode
    u = Undeletable()
    with pytest.raises(TypeError):
        del u.m
    with pytest.raises(TypeError):