コード例 #1
0
class AbstractWindow(HasTraits):
    """ Toolkit-independent base class for an Enable window.

    Houses the top-level ``component``, tracks mouse-capture and keyboard
    focus ownership, converts toolkit key/mouse/drag events into Enable
    events and dispatches them, and drives the paint loop (GC creation,
    layout, draw, then ``_window_paint``).  Methods that raise
    NotImplementedError must be provided by concrete toolkit backends.
    """

    # The top-level component that this window houses
    component = Instance(Component)

    # A reference to the nested component that has focus.  This is part of the
    # manual mechanism for determining keyboard focus.
    focus_owner = Instance(Interactor)

    # If set, this is the component to which all mouse events are passed,
    # bypassing the normal event propagation mechanism.
    mouse_owner = Instance(Interactor)

    # The transform to apply to mouse event positions to put them into the
    # relative coordinates of the mouse_owner component.
    mouse_owner_transform = Any()

    # When a component captures the mouse, it can optionally store a
    # dispatch order for events (until it releases the mouse).
    mouse_owner_dispatch_history = Trait(None, None, List)

    # The background color of the window.  The entire window first gets
    # painted with this color before the component gets to draw.
    bgcolor = ColorTrait("sys_window")

    # Unfortunately, for a while, there was a naming inconsistency and the
    # background color trait was named "bg_color".  This is still provided for
    # backwards compatibility but should not be used in new code.
    bg_color = Alias("bgcolor")

    # Modifier-key state, refreshed from every key event that reaches
    # _handle_key_event().
    alt_pressed = Bool(False)
    ctrl_pressed = Bool(False)
    shift_pressed = Bool(False)

    # A container that gets drawn after & on top of the main component, and
    # which receives events first.
    overlay = Instance(Container)

    # When the underlying toolkit control gets resized, this event gets set
    # to the new size of the window, expressed as a tuple (dx, dy).
    resized = Event

    # Whether to enable damaged region handling
    use_damaged_region = Bool(False)

    # The previous component that handled an event.  Used to generate
    # mouse_enter and mouse_leave events.  Right now this can only be
    # None, self.component, or self.overlay.
    _prev_event_handler = Instance(Component)

    # (dx, dy) integer size of the Window.
    _size = Trait(None, Tuple)

    # The regions to update upon redraw.  None means "the whole window";
    # a list accumulates damaged regions between paints.
    _update_region = Any

    #---------------------------------------------------------------------------
    #  Abstract methods that must be implemented by concrete subclasses
    #---------------------------------------------------------------------------

    def set_drag_result(self, result):
        """ Sets the result that should be returned to the system from the
        handling of the current drag operation.  Valid result values are:
        "error", "none", "copy", "move", "link", "cancel".  These have the
        meanings associated with their WX equivalents.
        """
        raise NotImplementedError

    def _capture_mouse(self):
        "Capture all future mouse events"
        raise NotImplementedError

    def _release_mouse(self):
        "Release the mouse capture"
        raise NotImplementedError

    def _create_key_event(self, event_type, event):
        """ Convert a GUI toolkit key event into a KeyEvent.

        *event_type* is the Enable event name ('key_pressed', 'key_released'
        or 'character'); *event* is the toolkit's native event (or None).
        May return None when no Enable event should be generated.
        """
        raise NotImplementedError

    def _create_mouse_event(self, event):
        """ Convert a GUI toolkit mouse event into a MouseEvent.

        May return None when no Enable event should be generated.
        """
        raise NotImplementedError

    def _create_drag_event(self, event):
        """ Convert a GUI toolkit drag event into an Enable drag event.

        May return None when no Enable event should be generated.
        """
        raise NotImplementedError

    def _redraw(self, coordinates=None):
        """ Request a redraw of the window, within just the (x,y,w,h) coordinates
        (if provided), or over the entire window if coordinates is None.
        """
        raise NotImplementedError

    def _get_control_size(self):
        "Get the size of the underlying toolkit control"
        raise NotImplementedError

    def _create_gc(self, size, pix_format="bgr24"):
        """ Create a Kiva graphics context of a specified size.  This method
        only gets called when the size of the window itself has changed.  To
        perform pre-draw initialization every time in the paint loop, use
        _init_gc().
        """
        raise NotImplementedError

    def _init_gc(self):
        """ Gives a GC a chance to initialize itself before components perform
        layout and draw.  This is called every time through the paint loop.

        When damaged-region handling is off, or no region was recorded, the
        whole GC is cleared to ``bgcolor``; otherwise drawing is clipped to
        the union of the recorded damaged regions.
        """
        gc = self._gc
        if self._update_region == [] or not self.use_damaged_region:
            self._update_region = None
        if self._update_region is None:
            gc.clear(self.bgcolor_)
        else:
            # Fixme: should use clip_to_rects
            update_union = sm.reduce(union_bounds, self._update_region)
            gc.clip_to_rect(*update_union)

    def _window_paint(self, event):
        "Do a GUI toolkit specific screen update"
        raise NotImplementedError

    def set_pointer(self, pointer):
        "Sets the current cursor shape"
        raise NotImplementedError

    def set_timer_interval(self, component, interval):
        "Set up or cancel a timer for a specified component"
        raise NotImplementedError

    def _set_focus(self):
        "Sets this window to have keyboard focus"
        raise NotImplementedError

    def screen_to_window(self, x, y):
        "Returns local window coordinates for given global screen coordinates"
        raise NotImplementedError

    def get_pointer_position(self):
        "Returns the current pointer position in local window coordinates"
        raise NotImplementedError

    #------------------------------------------------------------------------
    # Public methods
    #------------------------------------------------------------------------

    def __init__(self, **traits):
        self._scroll_origin = (0.0, 0.0)
        self._update_region = None
        self._gc = None
        self._pointer_owner = None
        HasTraits.__init__(self, **traits)

        # Create a default component (if necessary):
        if self.component is None:
            self.component = Container()

    def _component_changed(self, old, new):
        """ Re-wire the window when ``component`` is replaced.

        Detaches the old component (bounds listener and ``window`` back
        reference), installs a default Container when *new* is None, and
        otherwise sizes *new* to the toolkit control when possible before
        requesting a full redraw.
        """
        if old is not None:
            old.on_trait_change(self.component_bounds_changed,
                                'bounds',
                                remove=True)
            old.window = None

        if new is None:
            # Re-enters this handler with the freshly created Container.
            self.component = Container()
            return

        new.window = self

        # If possible, size the new component according to the size of the
        # toolkit control
        size = self._get_control_size()
        if (size is not None) and hasattr(new, "bounds"):
            new.on_trait_change(self.component_bounds_changed, 'bounds')
            if getattr(new, "fit_window", False):
                new.outer_position = [0, 0]
                new.outer_bounds = list(size)
            elif hasattr(new, "resizable"):
                if "h" in new.resizable:
                    new.outer_x = 0
                    new.outer_width = size[0]
                if "v" in new.resizable:
                    new.outer_y = 0
                    new.outer_height = size[1]
        self._update_region = None
        self.redraw()

    def component_bounds_changed(self, bounds):
        """
        Dynamic trait listener that handles our component changing its size;
        bounds is a length-2 list of [width, height].
        """
        self.invalidate_draw()

    def set_mouse_owner(self, mouse_owner, transform=None, history=None):
        """ Grab or release the mouse on behalf of *mouse_owner*.

        Passing None releases the toolkit capture and clears the owner, its
        transform and its dispatch history; otherwise the mouse is captured
        and the given transform / dispatch history are stored for use when
        routing subsequent events to the owner.
        """
        if mouse_owner is None:
            self._release_mouse()
            self.mouse_owner = None
            self.mouse_owner_transform = None
            self.mouse_owner_dispatch_history = None
        else:
            self._capture_mouse()
            self.mouse_owner = mouse_owner
            self.mouse_owner_transform = transform
            self.mouse_owner_dispatch_history = history

    def invalidate_draw(self, damaged_regions=None, self_relative=False):
        """ Mark part (or all) of the window as needing a redraw.

        *damaged_regions* are appended to the pending update region when one
        is being accumulated; otherwise the whole window is marked dirty by
        resetting the update region to None.  *self_relative* is currently
        ignored.
        """
        if damaged_regions is not None and self._update_region is not None:
            self._update_region += damaged_regions
        else:
            self._update_region = None

    #---------------------------------------------------------------------------
    #  Mouse-owner transform helper:
    #---------------------------------------------------------------------------

    def _push_mouse_owner_transform(self, event):
        """ Push onto *event* the transform into the mouse owner's space.

        If a dispatch history was recorded when the mouse was captured, the
        event transforms of its components are combined with ``dot`` (in
        reverse history order); otherwise the saved ``mouse_owner_transform``
        is pushed, if there is one.
        """
        history = self.mouse_owner_dispatch_history
        if history is not None and len(history) > 0:
            # Assemble all the transforms
            transforms = [c.get_event_transform() for c in history]
            total_transform = sm.reduce(dot, transforms[::-1])
            event.push_transform(total_transform)
        elif self.mouse_owner_transform is not None:
            event.push_transform(self.mouse_owner_transform)

    #---------------------------------------------------------------------------
    #  Generic keyboard event handler:
    #---------------------------------------------------------------------------

    def _handle_key_event(self, event_type, event):
        """ **event** should be a toolkit-specific opaque object that will
        be passed in to the backend's _create_key_event() method. It can
        be None if the toolkit lacks a native "key event" object.

        Returns True if the event has been handled within the Enable object
        hierarchy, or False otherwise.
        """
        # Generate the Enable event
        key_event = self._create_key_event(event_type, event)
        if key_event is None:
            return False

        self.shift_pressed = key_event.shift_down
        self.alt_pressed = key_event.alt_down
        # ``ctrl_pressed`` is the declared trait.  ``control_pressed`` (an
        # undeclared attribute) is still set so that existing readers of it
        # keep working.
        self.ctrl_pressed = key_event.control_down
        self.control_pressed = key_event.control_down

        # Dispatch the event to the correct component
        mouse_owner = self.mouse_owner
        if mouse_owner is not None:
            self._push_mouse_owner_transform(key_event)
            mouse_owner.dispatch(key_event, event_type)
        else:
            # Normal event handling loop
            if (not key_event.handled) and (self.component is not None):
                if self.component.is_in(key_event.x, key_event.y):
                    # Fire the actual event
                    self.component.dispatch(key_event, event_type)

        return key_event.handled

    #---------------------------------------------------------------------------
    #  Generic mouse event handler:
    #---------------------------------------------------------------------------
    def _handle_mouse_event(self, event_name, event, set_focus=False):
        """ **event** should be a toolkit-specific opaque object that will
        be passed in to the backend's _create_mouse_event() method.  It can
        be None if the toolkit lacks a native "mouse event" object.

        *set_focus* is tested for truthiness; when true, a button press or
        wheel movement gives the toolkit focus to this window and the
        component may become the focus owner.

        Returns True if the event has been handled within the Enable object
        hierarchy, or False otherwise.
        """
        if self._size is None:
            # PZW: Hack!
            # We need to handle the cases when the window hasn't been painted yet, but
            # it's gotten a mouse event.  In such a case, we just ignore the mouse event.
            # If the window has been painted, then _size will have some sensible value.
            return False

        mouse_event = self._create_mouse_event(event)
        # if no mouse event generated for some reason, return
        if mouse_event is None:
            return False

        mouse_owner = self.mouse_owner

        if mouse_owner is not None:
            # A mouse_owner has grabbed the mouse.  Compose the net transform
            # into its coordinate space and send the event straight to it.
            self._push_mouse_owner_transform(mouse_event)
            mouse_owner.dispatch(mouse_event, event_name)
            self._pointer_owner = mouse_owner
        else:
            # Normal event handling loop
            if self.overlay is not None:
                # TODO: implement this...
                pass
            if (not mouse_event.handled) and (self.component is not None):
                # Test to see if we need to generate a mouse_leave event
                if self._prev_event_handler:
                    if not self._prev_event_handler.is_in(
                            mouse_event.x, mouse_event.y):
                        self._prev_event_handler.dispatch(
                            mouse_event, "pre_mouse_leave")
                        mouse_event.handled = False
                        self._prev_event_handler.dispatch(
                            mouse_event, "mouse_leave")
                        self._prev_event_handler = None

                if self.component.is_in(mouse_event.x, mouse_event.y):
                    # Test to see if we need to generate a mouse_enter event
                    if self._prev_event_handler != self.component:
                        self._prev_event_handler = self.component
                        self.component.dispatch(mouse_event, "pre_mouse_enter")
                        mouse_event.handled = False
                        self.component.dispatch(mouse_event, "mouse_enter")

                    # Fire the actual event
                    self.component.dispatch(mouse_event, "pre_" + event_name)
                    mouse_event.handled = False
                    self.component.dispatch(mouse_event, event_name)

        # If this event requires setting the keyboard focus, set the first
        # component under the mouse pointer that accepts focus as the new focus
        # owner (otherwise, nobody owns the focus):
        if set_focus:
            # If the mouse event was a click, then we set the toolkit's
            # focus to ourselves
            if mouse_event.left_down or mouse_event.middle_down or \
                    mouse_event.right_down or mouse_event.mouse_wheel != 0:
                self._set_focus()

            if (self.component is not None) and (self.component.accepts_focus):
                if self.focus_owner is None:
                    self.focus_owner = self.component

        return mouse_event.handled

    #---------------------------------------------------------------------------
    #  Generic drag event handler:
    #---------------------------------------------------------------------------
    def _handle_drag_event(self, event_name, event, set_focus=False):
        """ **event** should be a toolkit-specific opaque object that will
        be passed in to the backend's _create_drag_event() method.  It can
        be None if the toolkit lacks a native "drag event" object.

        *set_focus* is accepted for signature parity with
        _handle_mouse_event() but is not used.

        Returns True if the event has been handled within the Enable object
        hierarchy, or False otherwise.
        """
        if self._size is None:
            # PZW: Hack!
            # We need to handle the cases when the window hasn't been painted yet, but
            # it's gotten a drag event.  In such a case, we just ignore the event.
            # If the window has been painted, then _size will have some sensible value.
            return False

        drag_event = self._create_drag_event(event)
        # if no drag event generated for some reason, return
        if drag_event is None:
            return False

        if self.component is not None:
            # Test to see if we need to generate a drag_leave event
            if self._prev_event_handler:
                if not self._prev_event_handler.is_in(drag_event.x,
                                                      drag_event.y):
                    self._prev_event_handler.dispatch(drag_event,
                                                      "pre_drag_leave")
                    drag_event.handled = False
                    self._prev_event_handler.dispatch(drag_event, "drag_leave")
                    self._prev_event_handler = None

            if self.component.is_in(drag_event.x, drag_event.y):
                # Test to see if we need to generate a drag_enter event
                if self._prev_event_handler != self.component:
                    self._prev_event_handler = self.component
                    self.component.dispatch(drag_event, "pre_drag_enter")
                    drag_event.handled = False
                    self.component.dispatch(drag_event, "drag_enter")

                # Fire the actual event
                self.component.dispatch(drag_event, "pre_" + event_name)
                drag_event.handled = False
                self.component.dispatch(drag_event, event_name)

        return drag_event.handled

    def set_tooltip(self, components):
        "Set the window's tooltip (if necessary)"
        raise NotImplementedError

    def redraw(self):
        """ Requests that the window be redrawn. """
        self._redraw()

    def cleanup(self):
        """ Clean up after ourselves.

        Detaches the component and drops the toolkit control and GC
        references.  Note that assigning ``component = None`` triggers
        _component_changed(), which installs a fresh default Container.
        """
        if self.component is not None:
            self.component.cleanup(self)
            self.component.parent = None
            self.component.window = None
            self.component = None

        self.control = None
        if self._gc is not None:
            self._gc.window = None
            self._gc = None

    def _needs_redraw(self, bounds):
        "Determine if a specified region intersects the update region"
        return does_disjoint_intersect_coordinates(
            self._update_region, bounds_to_coordinates(bounds))

    def _paint(self, event=None):
        """ This method is called directly by the UI toolkit's callback
        mechanism on the paint event.

        Recreates the GC when the control size changed, initializes it, lays
        out and draws the component, then lets the backend push the result to
        the screen via _window_paint().
        """
        if self.control is None:
            # the window has gone away, but let the window implementation
            # handle the event as needed
            self._window_paint(event)
            return

        # Create a new GC if necessary
        size = self._get_control_size()
        if (self._size != tuple(size)) or (self._gc is None):
            self._size = tuple(size)
            self._gc = self._create_gc(size)

        # Always give the GC a chance to initialize
        self._init_gc()

        # Layout components and draw
        if hasattr(self.component, "do_layout"):
            self.component.do_layout()
        gc = self._gc
        self.component.draw(gc, view_bounds=(0, 0, size[0], size[1]))

        # FIXME: consolidate damaged regions if necessary
        if not self.use_damaged_region:
            self._update_region = None

        # Perform a paint of the GC to the window (only necessary on backends
        # that render to an off-screen buffer)
        self._window_paint(event)

        # Start accumulating damaged regions afresh for the next paint.
        self._update_region = []

    def __getstate__(self):
        attribs = ("component", "bgcolor", "overlay", "_scroll_origin")
        state = {}
        for attrib in attribs:
            state[attrib] = getattr(self, attrib)
        return state

    #---------------------------------------------------------------------------
    # Wire up the mouse event handlers
    #---------------------------------------------------------------------------

    def _on_left_down(self, event):
        self._handle_mouse_event('left_down', event, set_focus=True)

    def _on_left_up(self, event):
        self._handle_mouse_event('left_up', event)

    def _on_left_dclick(self, event):
        self._handle_mouse_event('left_dclick', event)

    def _on_right_down(self, event):
        self._handle_mouse_event('right_down', event, set_focus=True)

    def _on_right_up(self, event):
        self._handle_mouse_event('right_up', event)

    def _on_right_dclick(self, event):
        self._handle_mouse_event('right_dclick', event)

    def _on_middle_down(self, event):
        self._handle_mouse_event('middle_down', event)

    def _on_middle_up(self, event):
        self._handle_mouse_event('middle_up', event)

    def _on_middle_dclick(self, event):
        self._handle_mouse_event('middle_dclick', event)

    def _on_mouse_move(self, event):
        self._handle_mouse_event('mouse_move', event, 1)

    def _on_mouse_wheel(self, event):
        self._handle_mouse_event('mouse_wheel', event)

    def _on_mouse_enter(self, event):
        self._handle_mouse_event('mouse_enter', event)

    def _on_mouse_leave(self, event):
        self._handle_mouse_event('mouse_leave', event, -1)

    # Additional event handlers that are not part of normal Interactors
    def _on_window_enter(self, event):
        # TODO: implement this to generate a mouse_leave on self.component
        pass

    def _on_window_leave(self, event):
        """ Send a final mouse_leave to the last component that saw the mouse. """
        if self._size is None:
            # PZW: Hack!
            # We need to handle the cases when the window hasn't been painted yet, but
            # it's gotten a mouse event.  In such a case, we just ignore the mouse event.
            # If the window has been painted, then _size will have some sensible value.
            self._prev_event_handler = None
        if self._prev_event_handler:
            mouse_event = self._create_mouse_event(event)
            # The backend may decline to build an event (as handled in
            # _handle_mouse_event); the pointer has still left the window.
            if mouse_event is not None:
                self._prev_event_handler.dispatch(mouse_event, "mouse_leave")
            self._prev_event_handler = None

    #---------------------------------------------------------------------------
    # Wire up the keyboard event handlers
    #---------------------------------------------------------------------------

    def _on_key_pressed(self, event):
        self._handle_key_event('key_pressed', event)

    def _on_key_released(self, event):
        self._handle_key_event('key_released', event)

    def _on_character(self, event):
        self._handle_key_event('character', event)
コード例 #2
0
class AramisView3D(HasTraits):
    '''This class manages 3D views for AramisCDT variables
    '''

    aramis_data = Instance(AramisFieldData)

    aramis_cdt = Instance(AramisCDT)

    scene = Instance(MlabSceneModel)

    plot_title = Bool(True)

    plot3d_var = Trait('07 d_ux [-]', {'01 x_arr [mm]':['aramis_data', 'x_arr_0'],
                                            '02 y_arr [mm]':['aramis_data', 'y_arr_0'],
                                            '03 z_arr [mm]':['aramis_data', 'z_arr_0'],
                                            '04 ux_arr [mm]':['aramis_data', 'ux_arr'],
                                            '05 uy_arr [mm]':['aramis_data', 'uy_arr'],
                                            '06 uz_arr [mm]':['aramis_data', 'uz_arr'],
                                            '07 d_ux [-]':['aramis_data', 'd_ux'],
                                            '08 crack_filed_arr [mm]': ['aramis_cdt', 'crack_field_arr'],
                                            '09 delta_ux_arr [mm]': ['aramis_data', 'delta_ux_arr'],
                                            '10 delta_uy_arr [mm]': ['aramis_data', 'delta_uy_arr']
                                            })

    plot3d_points_flat = Button
    def _plot3d_points_flat_fired(self):
        '''Plot the selected variable as flat, colour-mapped cubes.

        The variable selected by ``plot3d_var`` is drawn at the undeformed
        coordinates ``x_arr_0``, ``y_arr_0``, ``z_arr_0`` as unscaled cubes,
        seen from ``view(0, 0)`` with parallel projection.  A horizontal
        scalar bar titled with the attribute name and axes are added; the
        current step number is shown as the title when ``plot_title`` is set.
        '''
        self.scene.mlab.clf()
        m = self.scene.mlab
        # NOTE(review): these assign attributes on the mlab object itself;
        # confirm they actually affect the scene colours.
        m.fgcolor = (0, 0, 0)
        m.bgcolor = (1, 1, 1)
        # Suspend rendering while the pipeline is being built.
        self.scene.scene.disable_render = True

        data = self.aramis_data
        source_name, attr_name = self.plot3d_var_
        plot3d_var = getattr(getattr(self, source_name), attr_name)

        # All points are plotted; no NaN / x_0_mask filtering is applied
        # (the mask previously built here was immediately discarded).
        m.points3d(data.x_arr_0,
                   data.y_arr_0,
                   data.z_arr_0,
                   plot3d_var,
                   mode='cube',
                   scale_mode='none', scale_factor=1)
        m.view(0, 0)
        self.scene.scene.parallel_projection = True
        self.scene.scene.disable_render = False

        if self.plot_title:
            m.title('step no. %d' % data.current_step, size=0.3)

        m.scalarbar(orientation='horizontal', title=attr_name)

        # plot axes
        m.axes()

    glyph_x_length = Float(0.200)
    glyph_y_length = Float(0.200)
    glyph_z_length = Float(0.000)

    glyph_x_length_cr = Float(3.000)
    glyph_y_length_cr = Float(0.120)
    glyph_z_length_cr = Float(0.120)

    warp_factor = Float(0.0)

    plot3d_points = Button
    def _plot3d_points_fired(self):
        '''Plot the selected variable as a 3D relief of two glyph layers.

        Both layers are drawn in the plane z = 0 at the undeformed positions
        (``x_arr_0``, ``y_arr_0``).  The first layer scales its cubes by the
        variable value and uses the ``glyph_*_length_cr`` sizes; the second
        uses unscaled cubes sized by ``glyph_*_length`` and carries the
        horizontal scalar bar.  The view is set with ``view(0, 90)`` and
        parallel projection; axes and (optionally) the step title are added.
        '''
        self.scene.mlab.clf()
        m = self.scene.mlab
        # NOTE(review): these assign attributes on the mlab object itself;
        # confirm they actually affect the scene colours.
        m.fgcolor = (0, 0, 0)
        m.bgcolor = (1, 1, 1)
        # Suspend rendering while the pipeline is being built.
        self.scene.scene.disable_render = True

        data = self.aramis_data
        z_arr = np.zeros_like(data.z_arr_0)
        # The selected variable is looked up once and used for both layers.
        plot3d_var = getattr(getattr(self, self.plot3d_var_[0]), self.plot3d_var_[1])

        #-----------------------------------
        # layer 1: cubes scaled by the variable value
        #-----------------------------------
        m.points3d(z_arr, data.x_arr_0, data.y_arr_0, plot3d_var,
                   mode='cube', colormap="blue-red", scale_mode='scalar')

        # Glyph module of the first data source in the pipeline.
        glyph = self.scene.engine.scenes[0].children[0].children[0].children[0]
        glyph.glyph.glyph_source.glyph_position = 'tail'
        glyph.glyph.glyph_source.glyph_source.x_length = self.glyph_x_length_cr
        glyph.glyph.glyph_source.glyph_source.y_length = self.glyph_y_length_cr
        glyph.glyph.glyph_source.glyph_source.z_length = self.glyph_z_length_cr

        #-----------------------------------
        # layer 2: unscaled cubes
        #-----------------------------------
        m.points3d(z_arr, data.x_arr_0, data.y_arr_0, plot3d_var, mode='cube',
                   colormap="blue-red", scale_mode='none')

        # Glyph module of the second data source.  The glyph lengths are
        # assigned in rotated order because the first coordinate passed to
        # points3d is z_arr.
        glyph1 = self.scene.engine.scenes[0].children[1].children[0].children[0]
        glyph1.glyph.glyph_source.glyph_source.x_length = self.glyph_z_length
        glyph1.glyph.glyph_source.glyph_source.y_length = self.glyph_x_length
        glyph1.glyph.glyph_source.glyph_source.z_length = self.glyph_y_length

        # rotate scene
        self.scene.scene.parallel_projection = True
        m.view(0, 90)

        # Toggling the glyph position back and forth is presumably meant to
        # force the glyph source to refresh -- TODO confirm it is needed.
        glyph.glyph.glyph_source.glyph_position = 'head'
        glyph.glyph.glyph_source.glyph_position = 'tail'

        # Scalar bar attached to the second layer's module manager.
        module_manager = self.scene.engine.scenes[0].children[1].children[0]
        module_manager.scalar_lut_manager.show_scalar_bar = True
        module_manager.scalar_lut_manager.show_legend = True
        module_manager.scalar_lut_manager.scalar_bar.orientation = 'horizontal'
        module_manager.scalar_lut_manager.scalar_bar.title = self.plot3d_var_[1]
        module_manager.scalar_lut_manager.scalar_bar_representation.position = (0.10, 0.05)
        module_manager.scalar_lut_manager.scalar_bar_representation.position2 = (0.8, 0.15)
        self.scene.scene.disable_render = False

        if self.plot_title:
            m.title('step no. %d' % data.current_step, size=0.3)

        # plot axes
        m.axes()

    plot3d_cracks = Button
    def _plot3d_cracks_fired(self):
        '''Plot ``aramis_cdt.crack_field_arr`` in 3D as two glyph layers.

        Points are placed at the in-plane coordinates displaced by
        ``warp_factor`` times the displacements (``warp_factor = 0`` gives
        the undeformed configuration), in the plane z = 0.  The first layer
        scales cubes by the crack field value; the second uses unscaled cubes
        and carries the horizontal scalar bar, whose range is forced to start
        at zero.  Labels and titles are formatted in the 'times' font.
        '''
        aramis_cdt = self.aramis_cdt
        data = self.aramis_data

        self.scene.mlab.clf()
        m = self.scene.mlab
        # NOTE(review): these assign attributes on the mlab object itself;
        # confirm they actually affect the scene colours.
        m.fgcolor = (0, 0, 0)
        m.bgcolor = (1, 1, 1)
        # Suspend rendering while the pipeline is being built.
        self.scene.scene.disable_render = True

        z_arr = np.zeros_like(data.z_arr_0)
        # Warped coordinates shared by both layers.
        x_warped = data.x_arr_0 + data.ux_arr * self.warp_factor
        y_warped = data.y_arr_0 + data.uy_arr * self.warp_factor

        plot3d_var = aramis_cdt.crack_field_arr

        #-----------------------------------
        # layer 1: cubes scaled by the crack field value
        #-----------------------------------
        m.points3d(z_arr, x_warped, y_warped, plot3d_var,
                   mode='cube', colormap="blue-red", scale_mode='scalar', scale_factor=1.0)

        # Glyph module of the first data source in the pipeline.
        glyph = self.scene.engine.scenes[0].children[0].children[0].children[0]
        glyph.glyph.glyph_source.glyph_position = 'tail'
        glyph.glyph.glyph_source.glyph_source.x_length = self.glyph_x_length_cr
        glyph.glyph.glyph_source.glyph_source.y_length = self.glyph_y_length_cr
        glyph.glyph.glyph_source.glyph_source.z_length = self.glyph_z_length_cr

        #-----------------------------------
        # layer 2: unscaled cubes of crack_field_arr
        #-----------------------------------
        m.points3d(z_arr, x_warped, y_warped, plot3d_var,
                   mode='cube', colormap="blue-red", scale_mode='none', scale_factor=1)

        # Glyph module of the second data source.  The glyph lengths are
        # assigned in rotated order because the first coordinate passed to
        # points3d is z_arr.
        glyph1 = self.scene.engine.scenes[0].children[1].children[0].children[0]
        glyph1.glyph.glyph_source.glyph_source.x_length = self.glyph_z_length
        glyph1.glyph.glyph_source.glyph_source.y_length = self.glyph_x_length
        glyph1.glyph.glyph_source.glyph_source.z_length = self.glyph_y_length

        # rotate scene
        scene = self.scene.engine.scenes[0]
        scene.scene.parallel_projection = True
        m.view(0, 90)
        # Toggling the glyph position back and forth is presumably meant to
        # force the glyph source to refresh -- TODO confirm it is needed.
        glyph.glyph.glyph_source.glyph_position = 'head'
        glyph.glyph.glyph_source.glyph_position = 'tail'

        # Scalar bar attached to the second layer's module manager.
        module_manager = self.scene.engine.scenes[0].children[1].children[0]
        lut_manager = module_manager.scalar_lut_manager
        lut_manager.show_scalar_bar = True
        lut_manager.show_legend = True
        lut_manager.scalar_bar.orientation = 'horizontal'
        lut_manager.scalar_bar.title = 'delta_ux [mm]'
        lut_manager.scalar_bar_representation.position = (0.10, 0.05)
        lut_manager.scalar_bar_representation.position2 = (0.8, 0.15)
        scene.scene.disable_render = False

        if self.plot_title:
            m.title('step no. %d' % data.current_step, size=0.2)

        # Make the scalar bar start at zero.  For fixed ranges across steps,
        # replace wr_max by a constant (e.g. 6.90 mm).
        wr_max = lut_manager.data_range[1]
        lut_manager.data_range = np.array([0., wr_max])

        # format scalar bar labels in font style 'times'
        lut_manager.label_text_property.font_family = 'times'
        lut_manager.label_text_property.italic = False
        lut_manager.label_text_property.bold = False
        lut_manager.label_text_property.font_size = 25

        # title font of scalar bar
        lut_manager.title_text_property.font_family = 'times'
        lut_manager.title_text_property.bold = True
        lut_manager.title_text_property.italic = False
        lut_manager.title_text_property.font_size = 25

        # Title font of the plot (step no).  The text module only exists when
        # m.title() was called above, so the lookup must be guarded.
        if self.plot_title:
            text = scene.children[1].children[0].children[1]
            text.property.font_size = 25
            text.property.font_family = 'times'
            text.property.bold = True
            text.property.italic = False



    plot3d_var_deformed = Button

    def _plot3d_var_deformed_fired(self):
        '''Plot the selected 3D variable in the deformed configuration.

        The variable chosen by ``plot3d_var`` is drawn twice with cube glyphs
        placed at the deformed positions, i.e. the initial coordinates plus
        the displacements scaled by ``warp_factor``: first with glyphs scaled
        by the scalar value (sizes ``glyph_*_length_cr``), then with unscaled
        glyphs (sizes ``glyph_*_length``).  The scene is then switched to a
        parallel projection, viewed with ``m.view(0, 90)`` and decorated with
        a horizontal scalar bar titled after the variable and, if
        ``plot_title`` is set, the current step number.
        '''
        self.scene.mlab.clf()
        m = self.scene.mlab
        # black background with white foreground (swap the two tuples to get
        # a white background for printed figures)
        m.fgcolor = (1, 1, 1)
        m.bgcolor = (0, 0, 0)
        self.scene.scene.disable_render = True

        z_arr = np.zeros_like(self.aramis_data.z_arr_0)

        # plot3d_var_ holds (attribute of self, attribute of that object)
        plot3d_var = getattr(getattr(self, self.plot3d_var_[0]),
                             self.plot3d_var_[1])

        # deformed coordinates; computed once and shared by both plots below
        x_def = (self.aramis_data.x_arr_0
                 + self.aramis_data.ux_arr * self.warp_factor)
        y_def = (self.aramis_data.y_arr_0
                 + self.aramis_data.uy_arr * self.warp_factor)

        #-----------------------------------
        # plot variable with glyphs scaled by its value
        #-----------------------------------

        m.points3d(z_arr, x_def, y_def, plot3d_var,
                   mode='cube', colormap="blue-red", scale_mode='scalar', scale_factor=1.0)

        # scale glyphs
        #
        glyph = self.scene.engine.scenes[0].children[0].children[0].children[0]
        glyph.glyph.glyph_source.glyph_position = 'tail'
        glyph.glyph.glyph_source.glyph_source.x_length = self.glyph_x_length_cr
        glyph.glyph.glyph_source.glyph_source.y_length = self.glyph_y_length_cr
        glyph.glyph.glyph_source.glyph_source.z_length = self.glyph_z_length_cr

        #-----------------------------------
        # plot variable with unscaled glyphs
        #-----------------------------------

        m.points3d(z_arr, x_def, y_def, plot3d_var,
                   mode='cube', colormap="blue-red", scale_mode='none', scale_factor=1)

        glyph1 = self.scene.engine.scenes[0].children[1].children[0].children[0]
        # z_arr is passed as the first (x) coordinate above, so the glyph
        # lengths are permuted to match: x <- z, y <- x, z <- y
        glyph1.glyph.glyph_source.glyph_source.x_length = self.glyph_z_length
        glyph1.glyph.glyph_source.glyph_source.y_length = self.glyph_x_length
        glyph1.glyph.glyph_source.glyph_source.z_length = self.glyph_y_length

        # rotate scene
        #
        scene = self.scene.engine.scenes[0]
        scene.scene.parallel_projection = True
        m.view(0, 90)
        # NOTE(review): setting 'head' then 'tail' presumably forces the glyph
        # source to refresh -- confirm before removing
        glyph.glyph.glyph_source.glyph_position = 'head'
        glyph.glyph.glyph_source.glyph_position = 'tail'

        module_manager = self.scene.engine.scenes[0].children[1].children[0]
        module_manager.scalar_lut_manager.show_scalar_bar = True
        module_manager.scalar_lut_manager.show_legend = True
        module_manager.scalar_lut_manager.scalar_bar.orientation = 'horizontal'
        module_manager.scalar_lut_manager.scalar_bar.title = self.plot3d_var_[1]
        scene.scene.disable_render = False

        module_manager.scalar_lut_manager.scalar_bar_representation.position2 = np.array([ 0.7, 0.15])
        module_manager.scalar_lut_manager.scalar_bar_representation.position = np.array([ 0.2, 0.1])

        if self.plot_title:
            m.title('step no. %d' % self.aramis_data.current_step, size=0.2)

        # for scalar bar format values in font style 'times'
        module_manager.scalar_lut_manager.label_text_property.font_family = 'times'
        module_manager.scalar_lut_manager.label_text_property.italic = False
        module_manager.scalar_lut_manager.label_text_property.bold = False

        # title font of scalar bar
        module_manager.scalar_lut_manager.title_text_property.font_family = 'times'
        module_manager.scalar_lut_manager.title_text_property.bold = True
        module_manager.scalar_lut_manager.title_text_property.italic = False

        # title font of plot (step no)
        text = scene.children[1].children[0].children[1]
        text.property.font_size = 25
        text.property.font_family = 'times'
        text.property.bold = True
        text.property.italic = True


    clean_scene = Button

    def _clean_scene_fired(self):
        '''Clear the current mlab figure when the button is pressed.'''
        mlab = self.scene.mlab
        mlab.clf()


    # Layout: variable selector, the plot actions, a separator ('_'),
    # then the warp factor and a button to clear the scene.
    view = View(Item('plot3d_var'),
                UItem('plot3d_points_flat'),
                UItem('plot3d_points'),
                UItem('plot3d_cracks'),
                UItem('plot3d_var_deformed'),
                Item('_'),
                Item('warp_factor'),
                UItem('clean_scene'),
                id='aramisCDT.view3d')
コード例 #3
0
class ListUndoItem(AbstractUndoItem):
    """ A change to a list, which can be undone.

    The change is described as: starting at ``index`` in the list trait
    ``name`` of ``object``, the items in ``removed`` were replaced by the
    items in ``added``.
    """
    #---------------------------------------------------------------------------
    #  Trait definitions:
    #---------------------------------------------------------------------------

    # Object that the change occurred on
    object = Trait(HasTraits)
    # Name of the trait that changed
    name = Str
    # Starting index
    index = Int
    # Items added to the list
    added = List
    # Items removed from the list
    removed = List

    #---------------------------------------------------------------------------
    #  Undoes the change:
    #---------------------------------------------------------------------------

    def undo(self):
        """ Undoes the change.

        Puts the ``removed`` items back in place of the ``added`` ones.  Any
        ordinary error raised while doing so is ignored.
        """
        try:
            items = getattr(self.object, self.name)
            items[self.index:(self.index + len(self.added))] = self.removed
        except Exception:
            # Best effort: a change that can no longer be applied is skipped,
            # but KeyboardInterrupt/SystemExit are not swallowed.
            pass

    #---------------------------------------------------------------------------
    #  Re-does the change:
    #---------------------------------------------------------------------------

    def redo(self):
        """ Re-does the change.

        Puts the ``added`` items back in place of the ``removed`` ones.  Any
        ordinary error raised while doing so is ignored.
        """
        try:
            items = getattr(self.object, self.name)
            items[self.index:(self.index + len(self.removed))] = self.added
        except Exception:
            # Best effort, as in undo().
            pass

    #---------------------------------------------------------------------------
    #  Merges two undo items if possible:
    #---------------------------------------------------------------------------

    def merge_undo(self, undo_item):
        """ Merges two undo items if possible.

        Returns True when ``undo_item`` describes exactly the same change as
        this item (same class, same object, same trait name, same index and
        the very same added/removed item objects, in order), so that it can
        be discarded as a duplicate; returns False otherwise.
        """
        # Discard undo items that are identical to us. This is to eliminate
        # the same undo item being created by multiple listeners monitoring the
        # same list for changes:
        if not (isinstance(undo_item, self.__class__)
                and (self.object is undo_item.object)
                and (self.name == undo_item.name)
                and (self.index == undo_item.index)):
            return False
        return (self._same_items(self.added, undo_item.added)
                and self._same_items(self.removed, undo_item.removed))

    @staticmethod
    def _same_items(mine, theirs):
        """ Return True if both sequences hold the identical objects in order.
        """
        return (len(mine) == len(theirs)
                and all(a is b for a, b in zip(mine, theirs)))

    #---------------------------------------------------------------------------
    #  Returns a 'pretty print' form of the object:
    #---------------------------------------------------------------------------

    def __repr__(self):
        """ Returns a 'pretty print' form of the object.
        """
        return 'undo( %s.%s[%d:%d] = %s )' % (
            self.object.__class__.__name__, self.name, self.index,
            self.index + len(self.removed), self.added)
コード例 #4
0
    Tuple,
    Range,
    Trait,
)
from apptools.scripting.recorder import Recorder
from apptools.scripting.recordable import recordable
from apptools.scripting.package_globals import set_recorder

try:
    # PrefixMap is only available from Traits 6.1 onwards; older releases
    # fall back to the legacy TraitPrefixMap handler.
    from traits.api import PrefixMap
except ImportError:
    from traits.api import TraitPrefixMap

    # Legacy spelling: the default value is the first argument of Trait.
    representation_trait = Trait(
        "surface",
        TraitPrefixMap(dict(surface=2, wireframe=1, points=0)),
    )
else:
    # Modern spelling: the default is passed explicitly.  The mapped values
    # are surface -> 2, wireframe -> 1, points -> 0.
    representation_trait = PrefixMap(
        dict(surface=2, wireframe=1, points=0),
        default_value="surface",
    )


######################################################################
# Test classes.
class Property(HasStrictTraits):
    """Test class holding a single colour trait."""

    # A colour given as three components, each constrained to [0.0, 1.0].
    color = Tuple(
        Range(0.0, 1.0),
        Range(0.0, 1.0),
        Range(0.0, 1.0),
    )
コード例 #5
0
ファイル: browser.py プロジェクト: laurentgrenier/mathematics
class PipelineBrowser(HasTraits):
    """Tree view of a TVTK pipeline.

    Shows the objects reachable from ``root_object`` (or, when that list is
    empty, from the render windows of ``renwins``) in a tree editor.  Objects
    selected or double-clicked in the tree get ``render`` attached as a trait
    change handler, so edits re-render every window in ``renwins``.
    """
    # The tree generator to use.
    tree_generator = Trait(FullTreeGenerator(), Instance(TreeGenerator))

    # The TVTK render window(s) associated with this browser.
    renwins = List

    # The root object(s) to view in the pipeline.  If empty (default), the
    # root objects are the render_window of each scene in ``renwins``
    # (including the Scene instance passed at object instantiation time).
    root_object = List(TVTKBase)

    # The object currently selected in the tree.
    selected = Instance(TVTKBase)

    # This is fired when an object has been changed on the UI. Use this when
    # you do not set the renwins list with the render windows but wish to do
    # your own thing when the object traits are edited on the UI.
    object_edited = Event

    # Private traits.
    # The root of the tree to display.
    _root = Any
    # The UI created by show(), or None before the browser is shown.
    _ui = Any

    ###########################################################################
    # `object` interface.
    ###########################################################################
    def __init__(self, renwin=None, **traits):
        """Initializes the object.

        Parameters
        ----------

        - renwin: `Scene` instance.  Defaults to None.

          This may be passed in addition to the renwins attribute
          which can be a list of scenes.

        """
        super(PipelineBrowser, self).__init__(**traits)
        self._ui = None
        self.view = None
        if renwin:
            self.renwins.append(renwin)

        self._root_object_changed(self.root_object)

    def default_traits_view(self):
        """Build the browser's view: a tree editor above the editor for the
        currently selected object.

        Side effect: stores the context menu in ``self.menu`` and the tree
        editor in ``self.tree_editor``.
        """
        menu = Menu(Action(name='Refresh', action='editor.update_editor'),
                    Action(name='Expand all', action='editor.expand_all'))
        self.menu = menu

        nodes = self.tree_generator.get_nodes(menu)

        self.tree_editor = TreeEditor(nodes=nodes,
                                      editable=False,
                                      orientation='vertical',
                                      hide_root=True,
                                      on_select=self._on_select,
                                      on_dclick=self._on_dclick)
        view = View(Group(
            VSplit(Item(name='_root', editor=self.tree_editor, resizable=True),
                   Item(name='selected', style='custom', resizable=True),
                   show_labels=False,
                   show_border=False)),
                    title='Pipeline browser',
                    help=False,
                    resizable=True,
                    undo=False,
                    revert=False,
                    width=.3,
                    height=.3)
        return view

    ###########################################################################
    # `PipelineBrowser` interface.
    ###########################################################################
    def show(self, parent=None):
        """Show the tree view if not already show.  If optional
        `parent` widget is passed, the tree is displayed inside the
        passed parent widget."""
        # If UI already exists, raise it and return.
        if self._ui and self._ui.control:
            try:
                self._ui.control.Raise()
            except AttributeError:
                pass
            else:
                return
        else:
            # No active ui, create one.
            view = self.default_traits_view()
            if parent:
                self._ui = view.ui(self, parent=parent, kind='subpanel')
            else:
                self._ui = view.ui(self, parent=parent)

    def update(self):
        """Update the tree view (no-op if the UI has not been created)."""
        # This is a hack: it assumes the first editor of the UI is the tree.
        if self._ui and self._ui.control:
            try:
                ed = self._ui._editors[0]
                ed.update_editor()
                self._ui.control.Refresh()
            except (AttributeError, IndexError):
                pass

    # Another name for update.
    refresh = update

    def render(self):
        """Calls render on all render windows associated with this
        browser."""
        self.object_edited = True
        for rw in self.renwins:
            rw.render()

    ###########################################################################
    # Non-public interface.
    ###########################################################################
    def _make_default_root(self):
        """Return a root node wrapping the render window of each scene in
        ``renwins``."""
        tree_gen = self.tree_generator
        objs = [x.render_window for x in self.renwins]
        node = TVTKCollectionNode(object=objs,
                                  name="Root",
                                  tree_generator=tree_gen)
        return node

    def _tree_generator_changed(self, tree_gen):
        """Traits event handler: rebuild the root with the new generator."""
        if self._root:
            root_obj = self._root.object
        else:
            root_obj = self.root_object
        if root_obj:
            ro = root_obj
            if not hasattr(root_obj, '__len__'):
                ro = [root_obj]

            self._root = TVTKCollectionNode(object=ro,
                                            name="Root",
                                            tree_generator=tree_gen)
        else:
            self._root = self._make_default_root()

        # The tree editor and menu only exist once default_traits_view() has
        # run.  The generator may change before that (e.g. when passed to the
        # constructor); default_traits_view() then asks the current generator
        # for its nodes itself.
        tree_editor = getattr(self, 'tree_editor', None)
        if tree_editor is not None:
            tree_editor.nodes = tree_gen.get_nodes(self.menu)
        self.update()

    def _root_object_changed(self, root_obj):
        """Trait handler called when the root object is assigned to."""
        tg = self.tree_generator
        if root_obj:
            self._root = TVTKCollectionNode(object=root_obj,
                                            name="Root",
                                            tree_generator=tg)
        else:
            self._root = self._make_default_root()
            self.root_object = self._root.object
        self.update()

    def _root_object_items_changed(self, list_event):
        """Trait handler called when the items of the list change."""
        self._root_object_changed(self.root_object)

    def _on_dclick(self, obj):
        """Callback that is called when nodes are double-clicked.

        Opens a traits editor for the node's wrapped object and re-renders
        the associated windows whenever one of its traits changes.
        """
        if hasattr(obj, 'object') and hasattr(obj.object, 'edit_traits'):
            target = obj.object
            view = target.trait_view()
            view.handler = UICloseHandler(browser=self)
            target.on_trait_change(self.render)
            target.edit_traits(view=view)

    def _on_select(self, obj):
        """Callback for node selection: make the node's object the
        ``selected`` one and move the re-render listener onto it."""
        if hasattr(obj, 'object') and hasattr(obj.object, 'edit_traits'):
            new = obj.object
            old = self.selected
            if new != old:
                self.selected = new
            if old is not None:
                old.on_trait_change(self.render, remove=True)
            if new is not None:
                new.on_trait_change(self.render)
コード例 #6
0
class Bar(HasTraits):
    """Object exposing a single trait ``s`` handled by ``MyHandler``."""

    # Defaults to the empty string; MyHandler (defined elsewhere) is the
    # trait handler applied to values assigned to ``s``.
    s = Trait("", MyHandler())
コード例 #7
0
ファイル: helper_functions.py プロジェクト: victorliun/mayavi
class Mesh(Pipeline):
    """
    Plots a surface using grid-spaced data supplied as 2D arrays.

    **Function signatures**::

        mesh(x, y, z, ...)

    x, y, z are 2D arrays, all of the same shape, giving the positions of
    the vertices of the surface. The connectivity between these points is
    implied by the connectivity on the arrays.

    For simple structures (such as orthogonal grids) prefer the `surf`
    function, as it will create more efficient data structures. For mesh
    defined by triangles rather than regular implicit connectivity, see the
    `triangular_mesh` function.
    """

    scale_mode = Trait('none', {
        'none': 'data_scaling_off',
        'scalar': 'scale_by_scalar',
        'vector': 'scale_by_vector'
    },
                       help="""the scaling mode for the glyphs
                            ('vector', 'scalar', or 'none').""")

    scale_factor = CFloat(0.05,
                          desc="""scale factor of the glyphs used to represent
                        the vertices, in fancy_mesh mode. """)

    tube_radius = Trait(0.025,
                        CFloat,
                        None,
                        help="""radius of the tubes used to represent the
                        lines, in mesh mode. If None, simple lines are used.
                        """)

    scalars = Array(help="""optional scalar data.""")

    mask = Array(help="""boolean mask array to suppress some data points.
                 Note: this works based on colormapping of scalars and will
                 not work if you specify a solid color using the
                 `color` keyword.""")

    representation = Trait(
        'surface',
        'wireframe',
        'points',
        'mesh',
        'fancymesh',
        desc="""the representation type used for the surface.""")

    _source_function = Callable(grid_source)

    _pipeline = [
        ExtractEdgesFactory, GlyphFactory, TubeFactory, SurfaceFactory
    ]

    def __call_internal__(self, *args, **kwargs):
        """ Override the call to be able to choose whether to apply
        filters.

        For 'surface', 'wireframe' and 'points' the edge/glyph/tube stages
        are dropped and a normals stage is prepended.  For 'mesh' and
        'fancymesh' the edges are extracted and drawn as tubes (plain lines
        when ``tube_radius`` is None), glyphs are kept only for 'fancymesh',
        and the surface module is then given the 'surface' representation.
        """
        self.source = self._source_function(*args, **kwargs)
        kwargs.pop('name', None)
        self.store_kwargs(kwargs)
        # Copy the pipeline so as not to modify it for the next call
        self.pipeline = self._pipeline[:]
        representation = self.kwargs['representation']
        if representation not in ('mesh', 'fancymesh'):
            self.pipeline.remove(ExtractEdgesFactory)
            self.pipeline.remove(TubeFactory)
            self.pipeline.remove(GlyphFactory)
            self.pipeline = [
                PolyDataNormalsFactory,
            ] + self.pipeline
        else:
            if self.kwargs['tube_radius'] is None:
                self.pipeline.remove(TubeFactory)
            if representation != 'fancymesh':
                self.pipeline.remove(GlyphFactory)
            # The edges/tubes/glyphs carry the mesh look; the underlying
            # surface module itself is rendered as a surface.
            self.kwargs['representation'] = 'surface'
        return self.build_pipeline()
コード例 #8
0
ファイル: mlab.py プロジェクト: victorliun/mayavi
class LUTBase(MLabBase):
    """Base class owning a lookup table and an associated scalar bar.

    The lookup table colours follow ``lut_type``; the scalar bar (and its
    widget) use that table and ``legend_text`` as title, and are shown when
    ``show_scalar_bar`` is set and a render window is attached.
    """
    # The choices for the lookuptable
    lut_type = Trait('red-blue', 'red-blue', 'blue-red',
                     'black-white', 'white-black',
                     desc='the type of the lookup table')

    # The LookupTable instance.
    lut = Instance(tvtk.LookupTable, ())

    # The scalar bar.
    scalar_bar = Instance(tvtk.ScalarBarActor, (),
                          {'orientation':'horizontal',
                           'width':0.8, 'height':0.17})

    # The scalar_bar widget.
    scalar_bar_widget = Instance(tvtk.ScalarBarWidget, ())

    # The legend name for the scalar bar.
    legend_text = Str('Scalar', desc='the title of the legend')

    # Turn on/off the visibility of the scalar bar.
    show_scalar_bar = Bool(False,
                           desc='specifies if scalar bar is shown or not')

    def __init__(self, **traits):
        super(LUTBase, self).__init__(**traits)
        self.lut.number_of_colors = 256
        self._lut_type_changed(self.lut_type)
        # trait_set is the non-deprecated spelling of HasTraits.set.
        self.scalar_bar.trait_set(lookup_table=self.lut,
                                  title=self.legend_text)
        # Anchor the bar near the bottom-left of the viewport.
        pc = self.scalar_bar.position_coordinate
        pc.coordinate_system = 'normalized_viewport'
        pc.value = 0.1, 0.01, 0.0
        self.scalar_bar_widget.trait_set(scalar_bar_actor=self.scalar_bar,
                                         key_press_activation=False)

    def _lut_type_changed(self, val):
        """Rebuild the lookup table for the table type ``val`` and render.

        Raises ValueError for a type not listed in ``lut_type``.
        """
        # (hue_range, saturation_range, value_range) for each table type.
        ranges = {
            'red-blue': ((0.0, 0.6667), (1.0, 1.0), (1.0, 1.0)),
            'blue-red': ((0.6667, 0.0), (1.0, 1.0), (1.0, 1.0)),
            'black-white': ((0.0, 0.0), (0.0, 0.0), (0.0, 1.0)),
            'white-black': ((0.0, 0.0), (0.0, 0.0), (1.0, 0.0)),
        }
        try:
            hue_range, saturation_range, value_range = ranges[val]
        except KeyError:
            raise ValueError('unknown lookup table type: %r' % (val,))
        lut = self.lut
        lut.trait_set(hue_range=hue_range, saturation_range=saturation_range,
                      value_range=value_range, number_of_table_values=256,
                      ramp='sqrt')
        lut.force_build()

        self.render()

    def _legend_text_changed(self, val):
        """Use the new legend text as the scalar bar title and render."""
        self.scalar_bar.title = val
        self.scalar_bar.modified()
        self.render()

    def _show_scalar_bar_changed(self, val):
        """Enable/disable the scalar bar widget if a render window exists."""
        if self.renwin:
            self.scalar_bar_widget.enabled = val
            self.renwin.render()

    def _renwin_changed(self, old, new):
        """Move the scalar bar widget from the old to the new interactor."""
        sbw = self.scalar_bar_widget
        if old:
            sbw.interactor = None
            old.render()
        if new:
            sbw.interactor = new.interactor
            sbw.enabled = self.show_scalar_bar
            new.render()
        super(LUTBase, self)._renwin_changed(old, new)
コード例 #9
0
ファイル: array_source.py プロジェクト: zyex1108/mayavi
class ArraySource(Source):

    """A simple source that allows one to view a suitably shaped numpy
    array as ImageData.  This supports both scalar and vector data.
    """

    # The scalar array data we manage.
    scalar_data = Trait(None, _check_scalar_array, rich_compare=False)

    # The name of our scalar array.
    scalar_name = Str('scalar')

    # The vector array data we manage.
    vector_data = Trait(None, _check_vector_array, rich_compare=False)

    # The name of our vector array.
    vector_name = Str('vector')

    # The spacing of the points in the array.
    spacing = DelegatesTo('change_information_filter', 'output_spacing',
                          desc='the spacing between points in array')

    # The origin of the points in the array.
    origin = DelegatesTo('change_information_filter', 'output_origin',
                         desc='the origin of the points in array')

    # Fire an event to update the spacing and origin. This
    # is here for backwards compatibility. Firing this is no
    # longer needed.
    update_image_data = Button('Update spacing and origin')

    # The image data stored by this instance.
    image_data = Instance(tvtk.ImageData, (), allow_none=False)

    # Use an ImageChangeInformation filter to reliably set the
    # spacing and origin on the output
    change_information_filter = Instance(tvtk.ImageChangeInformation, args=(),
                                         kw={'output_spacing' : (1.0, 1.0, 1.0),
                                             'output_origin' : (0.0, 0.0, 0.0)})

    # Should we transpose the input data or not.  Transposing is
    # necessary to make the numpy array compatible with the way VTK
    # needs it.  However, transposing numpy arrays makes them
    # non-contiguous where the data is copied by VTK.  Thus, when the
    # user explicitly requests that transpose_input_array is false
    # then we assume that the array has already been suitably
    # formatted by the user.
    transpose_input_array = Bool(True, desc='if input array should be transposed (if on VTK will copy the input data)')

    # Information about what this object can produce.
    output_info = PipelineInfo(datasets=['image_data'])

    # Specify the order of dimensions. The default is: [0, 1, 2]
    dimensions_order = List(Int, [0, 1, 2])

    # Our view.
    view = View(Group(Item(name='transpose_input_array'),
                      Item(name='scalar_name'),
                      Item(name='vector_name'),
                      Item(name='spacing'),
                      Item(name='origin'),
                      show_labels=True)
                )

    ######################################################################
    # `object` interface.
    ######################################################################
    def __init__(self, **traits):
        # Set the scalar and vector data at the end so we pop it here.
        sd = traits.pop('scalar_data', None)
        vd = traits.pop('vector_data', None)
        # Now set the other traits.
        super(ArraySource, self).__init__(**traits)
        self.configure_input_data(self.change_information_filter,
                                  self.image_data)

        # And finally set the scalar and vector data.
        if sd is not None:
            self.scalar_data = sd
        if vd is not None:
            self.vector_data = vd

        self.outputs = [ self.change_information_filter.output ]
        self.on_trait_change(self._information_changed, 'spacing,origin')

    def __get_pure_state__(self):
        """Return the persisted state, excluding the image data itself."""
        d = super(ArraySource, self).__get_pure_state__()
        d.pop('image_data', None)
        return d

    ######################################################################
    # ArraySource interface.
    ######################################################################
    def update(self):
        """Call this function when you change the array data
        in-place."""
        d = self.image_data
        d.modified()
        pd = d.point_data
        if self.scalar_data is not None:
            pd.scalars.modified()
        if self.vector_data is not None:
            pd.vectors.modified()
        self.change_information_filter.update()
        self.data_changed = True

    ######################################################################
    # Non-public interface.
    ######################################################################

    def _image_data_changed(self, value):
        """Feed a newly assigned image data object into the filter."""
        self.configure_input_data(self.change_information_filter, value)

    def _scalar_data_changed(self, data):
        """Load ``data`` as the point scalars (or clear them when None).

        A 2-D array is treated as having a singleton third dimension.
        """
        img_data = self.image_data
        if data is None:
            img_data.point_data.scalars = None
            self.data_changed = True
            return
        dims = list(data.shape)
        if len(dims) == 2:
            dims.append(1)

        # set the dimension indices
        dim0, dim1, dim2 = self.dimensions_order

        img_data.origin = tuple(self.origin)
        img_data.dimensions = tuple(dims)
        img_data.extent = 0, dims[dim0]-1, 0, dims[dim1]-1, 0, dims[dim2]-1
        if is_old_pipeline():
            img_data.update_extent = 0, dims[dim0]-1, 0, dims[dim1]-1, 0, dims[dim2]-1
        else:
            update_extent = [0, dims[dim0]-1, 0, dims[dim1]-1, 0, dims[dim2]-1]
            self.change_information_filter.set_update_extent(update_extent)
        if self.transpose_input_array:
            img_data.point_data.scalars = numpy.ravel(numpy.transpose(data))
        else:
            img_data.point_data.scalars = numpy.ravel(data)
        img_data.point_data.scalars.name = self.scalar_name
        # This is very important and if not done can lead to a segfault!
        typecode = data.dtype
        if is_old_pipeline():
            img_data.scalar_type = array_handler.get_vtk_array_type(typecode)
            img_data.update() # This sets up the extents correctly.
        else:
            filter_out_info = self.change_information_filter.get_output_information(0)
            img_data.set_point_data_active_scalar_info(filter_out_info,
                    array_handler.get_vtk_array_type(typecode), -1)
            img_data.modified()
        img_data.update_traits()
        self.change_information_filter.update()

        # Now flush the mayavi pipeline.
        self.data_changed = True

    def _vector_data_changed(self, data):
        """Load ``data`` as the point vectors (or clear them when None).

        A 3-D array gets a singleton third axis inserted; the last axis is
        treated as the vector components (the data is reshaped into rows of
        3 values).
        """
        img_data = self.image_data
        if data is None:
            img_data.point_data.vectors = None
            self.data_changed = True
            return
        dims = list(data.shape)
        if len(dims) == 3:
            dims.insert(2, 1)
            data = numpy.reshape(data, dims)

        img_data.origin = tuple(self.origin)
        img_data.dimensions = tuple(dims[:-1])
        img_data.extent = 0, dims[0]-1, 0, dims[1]-1, 0, dims[2]-1
        if is_old_pipeline():
            img_data.update_extent = 0, dims[0]-1, 0, dims[1]-1, 0, dims[2]-1
        else:
            self.change_information_filter.update_information()
            update_extent = [0, dims[0]-1, 0, dims[1]-1, 0, dims[2]-1]
            self.change_information_filter.set_update_extent(update_extent)
        sz = numpy.size(data)
        if self.transpose_input_array:
            data_t = numpy.transpose(data, (2, 1, 0, 3))
        else:
            data_t = data
        # Floor division: the row count must be an integer (``sz/3`` is a
        # float under true division and numpy.reshape rejects it).
        img_data.point_data.vectors = numpy.reshape(data_t, (sz // 3, 3))
        img_data.point_data.vectors.name = self.vector_name
        if is_old_pipeline():
            img_data.update() # This sets up the extents correctly.
        else:
            img_data.modified()
        img_data.update_traits()
        self.change_information_filter.update()

        # Now flush the mayavi pipeline.
        self.data_changed = True

    def _scalar_name_changed(self, value):
        """Rename the scalar array, if one is loaded."""
        if self.scalar_data is not None:
            self.image_data.point_data.scalars.name = value
            self.data_changed = True

    def _vector_name_changed(self, value):
        """Rename the vector array, if one is loaded."""
        if self.vector_data is not None:
            self.image_data.point_data.vectors.name = value
            self.data_changed = True

    def _transpose_input_array_changed(self, value):
        """Reload any loaded data so the new transpose setting applies."""
        if self.scalar_data is not None:
            self._scalar_data_changed(self.scalar_data)
        if self.vector_data is not None:
            self._vector_data_changed(self.vector_data)

    def _information_changed(self):
        """Re-run the filter after a spacing/origin change."""
        self.change_information_filter.update()
        self.data_changed = True
コード例 #10
0
class Controller(HasTraits):
    """Poll an ADXL345 accelerometer on every timer tick and stream a +1/-1
    signal derived from its X axis into a plot viewer.
    """

    # A reference to the plot viewer object
    viewer = Instance(Viewer)

    # Some parameters controlling the random signal that will be generated.
    # NOTE(review): timer_tick plots accelerometer data and does not read
    # mean, stddev or _generator -- confirm whether they are still needed.
    distribution_type = Enum("normal")
    mean = Float(0.0)
    stddev = Float(1.0)

    # The max number of data points to accumulate and show in the plot
    max_num_points = Int(100)

    # The number of data points we have received; we need to keep track of
    # this in order to generate the correct x axis data series.
    num_ticks = Int(0)

    # private reference to the random number generator.  this syntax
    # just means that self._generator should be initialized to
    # random.normal, which is a random number function, and in the future
    # it can be set to any callable object.
    _generator = Trait(np.random.normal, Callable)

    view = View(Group('distribution_type',
                      'mean',
                      'stddev',
                      'max_num_points',
                      orientation="vertical"),
                buttons=["OK", "Cancel"])

    def timer_tick(self, *args):
        """
        Callback function that should get called based on a timer tick.  This
        reads the ADXL345 accelerometer, turns its X-axis reading into a +1/-1
        value and appends it to the `.data` array of our viewer object.
        """
        self._configure_accelerometer()

        # All three axes are read (same bus traffic as a full sample), but
        # only the X axis drives the plotted value.
        x_accl = self._read_axis(0x32)  # X-Axis LSB/MSB at 0x32/0x33
        self._read_axis(0x34)           # Y-Axis LSB/MSB at 0x34/0x35 (unused)
        self._read_axis(0x36)           # Z-Axis LSB/MSB at 0x36/0x37 (unused)

        # "On" (+1) whenever the X reading falls outside the 220..285 band.
        new_val = 1 if (x_accl > 285 or x_accl < 220) else -1
        # NOTE(review): the counter advances by 15 per tick rather than 1;
        # presumably intentional -- confirm.
        self.num_ticks += 15

        # Keep at most max_num_points - 1 old points, then append the new one.
        # (A plain ``cur_data[-(max_num_points - 1):]`` would keep the whole
        # history when max_num_points == 1, because -0 slices from the start.)
        cur_data = self.viewer.data
        keep = max(self.max_num_points - 1, 0)
        tail = cur_data[-keep:] if keep > 0 else cur_data[:0]
        new_data = np.hstack((tail, [new_val]))
        new_index = np.arange(self.num_ticks - len(new_data) + 1,
                              self.num_ticks + 0.01)

        self.viewer.index = new_index
        self.viewer.data = new_data

    def _configure_accelerometer(self):
        """Write the ADXL345 (I2C address 0x53) configuration registers."""
        # Bandwidth rate register 0x2C: 0x0A -> normal mode, 100 Hz data rate
        bus.write_byte_data(0x53, 0x2C, 0x0A)
        # Power control register 0x2D: 0x08 -> auto sleep disabled
        bus.write_byte_data(0x53, 0x2D, 0x08)
        # Data format register 0x31: 0x08 -> self test disabled, 4-wire
        # interface, full resolution, range +/-2g
        bus.write_byte_data(0x53, 0x31, 0x08)

    def _read_axis(self, lsb_register):
        """Read one ADXL345 axis and return it as a signed 10-bit integer.

        The LSB is read from ``lsb_register`` and the MSB from the register
        right after it; only the low two bits of the MSB are used.
        """
        lsb = bus.read_byte_data(0x53, lsb_register)
        msb = bus.read_byte_data(0x53, lsb_register + 1)
        value = ((msb & 0x03) * 256) + lsb
        # Two's-complement sign extension of the 10-bit value.
        if value > 511:
            value -= 1024
        return value

    def _distribution_type_changed(self):
        # This listens for a change in the type of distribution to use and
        # selects the matching generator.  "normal" is the only value the
        # Enum allows.  No polling happens here: a blocking loop in a trait
        # handler would freeze the UI thread.
        self._generator = np.random.normal
コード例 #11
0
class PlotScrollBar(NativeScrollBar):
    """
    A ScrollBar that can be wired up to anything with an xrange or yrange
    and which can be attached to a plot container.
    """

    # The axis corresponding to this scrollbar.
    axis = Enum("index", "value")

    # The renderer or Plot to attach this scrollbar to.  By default, this
    # is just self.component.
    plot = Property

    # The mapper associated with the scrollbar. By default, this is the
    # mapper on **plot** that corresponds to **axis**.
    mapper = Property

    # The index (0 or 1) of this scrollbar's axis within an (x, y) screen
    # coordinate tuple.  Backed by _axis_index; computed by _determine_axis()
    # while that is None.  Read by _get_axis_coord().
    axis_index = Property

    #------------------------------------------------------------------------
    # Private traits
    #------------------------------------------------------------------------

    # The value of the override plot to use, if any.  If None, then uses
    # self.component.
    _plot = Trait(None, Any)

    # The value of the override mapper to use, if any.  If None, then uses the
    # mapper on self.plot.
    _mapper = Trait(None, Any)

    # Stores the index (0 or 1) corresponding to self.axis
    _axis_index = Trait(None, None, Int)

    #----------------------------------------------------------------------
    # Public methods
    #----------------------------------------------------------------------

    def force_data_update(self):
        """ This forces the scrollbar to recompute its range bounds.  This
        should be used if datasources are changed out on the range, or if
        the data ranges on existing datasources of the range are changed.
        """
        self._handle_dataspace_update()

    def overlay(self, component, gc, view_bounds=None, mode="default"):
        self.do_layout()
        self._draw_mainlayer(gc, view_bounds, "default")

    def _draw_plot(self, gc, view_bounds=None, mode="default"):
        self._draw_mainlayer(gc, view_bounds, "default")

    def _do_layout(self):
        if getattr(self.plot, "_layout_needed", False):
            self.plot.do_layout()
        axis = self._determine_axis()
        low, high = self.mapper.screen_bounds
        self.bounds[axis] = high - low
        self.position[axis] = low
        self._widget_moved = True

    def _get_abs_coords(self, x, y):
        if self.container is not None:
            return self.container.get_absolute_coords(x, y)
        else:
            return self.component.get_absolute_coords(x, y)

    #----------------------------------------------------------------------
    # Scrollbar
    #----------------------------------------------------------------------

    def _handle_dataspace_update(self):
        # This method responds to changes from the dataspace side, e.g.
        # a change in the range bounds or the data bounds of the datasource.

        # Get the current bounds of every non-empty datasource on the range.
        data_range = self.mapper.range
        bounds_list = [source.get_bounds() for source in data_range.sources
                       if source.get_size() > 0]
        if bounds_list:
            mins, maxes = zip(*bounds_list)
            dmin = min(mins)
            dmax = max(maxes)
        else:
            # No data yet: use the range's own bounds, which leaves nothing
            # to scroll instead of failing on an empty zip().
            dmin, dmax = data_range.low, data_range.high

        view = float(data_range.high - data_range.low)

        # Take into account the range's current low/high and the data bounds
        # to compute the total range
        totalmin = min(data_range.low, dmin)
        totalmax = max(data_range.high, dmax)

        # Compute the size available for the scrollbar to scroll in
        scrollrange = (totalmax - totalmin) - view
        if round(scrollrange / 20.0) > 0.0:
            ticksize = scrollrange / round(scrollrange / 20.0)
        else:
            ticksize = 1
        scrollbar_range = (totalmin, totalmax, view, ticksize)
        self.trait_setq(range=scrollbar_range,
                        scroll_position=max(
                            min(self.scroll_position, totalmax - view),
                            totalmin))
        self._scroll_updated = True
        self.request_redraw()

    def _scroll_position_changed(self):
        super(PlotScrollBar, self)._scroll_position_changed()

        # Notify our range that we've changed, keeping the view width.
        data_range = self.mapper.range
        view_width = data_range.high - data_range.low
        new_scroll_pos = self.scroll_position
        data_range.set_bounds(new_scroll_pos, new_scroll_pos + view_width)

    #----------------------------------------------------------------------
    # Event listeners
    #----------------------------------------------------------------------

    def _component_changed(self, old, new):
        # Check to see if we're currently overriding the value of self.component
        # in self.plot.  If so, then don't change the event listeners.
        if self._plot is not None:
            return
        if old is not None:
            self._modify_plot_listeners(old, "detach")
        if new is not None:
            self._modify_plot_listeners(new, "attach")
            self._update_mapper_listeners()

    def __plot_changed(self, old, new):
        if old is not None:
            self._modify_plot_listeners(old, "detach")
        elif self.component is not None:
            # Remove listeners from self.component, if it exists
            self._modify_plot_listeners(self.component, "detach")
        if new is not None:
            self._modify_plot_listeners(new, "attach")
            self._update_mapper_listeners()
        elif self.component is not None:
            self._modify_plot_listeners(self.component, "attach")
            self._update_mapper_listeners()

    def _modify_plot_listeners(self, plot, action="attach"):
        # Any action other than "attach" removes the listeners.
        remove = (action != "attach")
        listeners = (
            (self._component_bounds_handler, "bounds"),
            (self._component_bounds_handler, "bounds_items"),
            (self._component_pos_handler, "position"),
            (self._component_pos_handler, "position_items"),
        )
        for handler, trait_name in listeners:
            plot.on_trait_change(handler, trait_name, remove=remove)

    def _component_bounds_handler(self):
        self._handle_dataspace_update()
        self._widget_moved = True

    def _component_pos_handler(self):
        self._handle_dataspace_update()
        self._widget_moved = True

    def _update_mapper_listeners(self):
        # Placeholder: mapper listeners are not wired up yet.
        pass

    def _handle_mapper_updated(self):
        self._handle_dataspace_update()

    #------------------------------------------------------------------------
    # Property getter/setters
    #------------------------------------------------------------------------

    def _get_plot(self):
        if self._plot is not None:
            return self._plot
        else:
            return self.component

    def _set_plot(self, val):
        self._plot = val

    def _get_mapper(self):
        if self._mapper is not None:
            return self._mapper
        else:
            return getattr(self.plot, self.axis + "_mapper")

    def _set_mapper(self, new_mapper):
        self._mapper = new_mapper

    def _get_axis_index(self):
        if self._axis_index is None:
            return self._determine_axis()
        else:
            return self._axis_index

    def _set_axis_index(self, val):
        self._axis_index = val

    #------------------------------------------------------------------------
    # Private methods
    #------------------------------------------------------------------------

    def _get_axis_coord(self, event, axis="index"):
        """ Returns the coordinate of the event along the axis of interest
        to this tool (or along the orthogonal axis, if axis="value").
        """
        event_pos = (event.x, event.y)
        if axis == "index":
            return event_pos[self.axis_index]
        else:
            return event_pos[1 - self.axis_index]

    def _determine_axis(self):
        """ Determines whether the index of the coordinate along this tool's
        axis of interest is the first or second element of an (x,y) coordinate
        tuple.

        This method is only called if self._axis_index hasn't been set (or is
        None).
        """
        if self.axis == "index":
            if self.plot.orientation == "h":
                return 0
            else:
                return 1
        else:  # self.axis == "value"
            if self.plot.orientation == "h":
                return 1
            else:
                return 0
コード例 #12
0
class CamMove(HasStrictTraits):
    '''Camera transition between two camera stations.

    For each camera attribute (azimuth, elevation, distance, focal point and
    roll angle) a selectable move function maps the value at
    ``from_station`` to the value at ``to_station`` over ``n_t`` steps.
    '''

    fta = WeakRef
    ftv = WeakRef
    from_station = WeakRef(CamStation)
    to_station = WeakRef(CamStation)

    # Firing this invalidates the cached ``transition_arr``.
    changed = Event

    # Camera attributes interpolated during the move, in the order their
    # trajectories appear in ``transition_arr``.
    cam_attributes = [
        'azimuth', 'elevation', 'distance', 'fpoint', 'roll']

    # Move function per attribute ('linear' or 'damped').  The mapped value
    # ``<attribute>_move_`` is the callable used by ``transition_arr``.
    azimuth_move = Trait('linear', {'linear': linear_cam_move,
                                    'damped': damped_cam_move})
    elevation_move = Trait('linear', {'linear': linear_cam_move,
                                      'damped': damped_cam_move})
    distance_move = Trait('linear', {'linear': linear_cam_move,
                                     'damped': damped_cam_move})
    fpoint_move = Trait('linear', {'linear': linear_cam_move,
                                   'damped': damped_cam_move})
    roll_move = Trait('linear', {'linear': linear_cam_move,
                                 'damped': damped_cam_move})

    duration = Float(10, label='Duration')

    # Start time of the move: the time stamp of ``from_station``.
    bts = Property(label='Start time')

    def _get_bts(self):
        return self.from_station.time_stemp

    # End time of the move: the start time plus ``duration``.
    ets = Property(label='End time')

    def _get_ets(self):
        start = self.from_station.time_stemp
        return start + self.duration

    n_t = Int(10, input=True)

    # Relative camera move time (CMT) running from zero to one.
    cmt = Property(Array('float_'), depends_on='n_t')

    @cached_property
    def _get_cmt(self):
        return np.linspace(0, 1, num=self.n_t)

    # Time line range covered during the camera move (``bts`` to ``ets``).
    viz_t_move = Property(Array('float_'))

    def _get_viz_t_move(self):
        return np.linspace(self.bts, self.ets, num=self.n_t)

    vot_start = Float(0.0, auto_set=False, enter_set=True, input=True)
    vot_end = Float(1.0, auto_set=False, enter_set=True, input=True)

    # Visualization object time (VOT): ``n_t`` steps from ``vot_start`` to
    # ``vot_end``.
    vot = Property(Array('float_'), depends_on='n_t,vot_start,vot_end')

    @cached_property
    def _get_vot(self):
        return np.linspace(self.vot_start, self.vot_end, num=self.n_t)

    def _get_vis3d_center_t(self):
        '''Get the center of the object'''
        # Returns ftv's ``get_center_t`` attribute itself; it is not called.
        return self.ftv.get_center_t

    def _get_vis3d_bounding_box_t(self):
        '''Get the bounding box of the object'''
        # NOTE(review): ``vis`` and ``t_range`` are not traits of this
        # HasStrictTraits class, so this presumably raises AttributeError
        # unless they are provided elsewhere -- confirm before use.
        return self.vis.get_center(self.t_range)

    # Per-step trajectories for every camera attribute (in cam_attributes
    # order), followed by the ``vot`` and ``viz_t_move`` arrays.
    transition_arr = Property(
        Array(dtype='float_'), depends_on='changed,+input')

    @cached_property
    def _get_transition_arr(self):
        start, end = self.from_station, self.to_station
        trajectories = []
        for attr in self.cam_attributes:
            move_fn = getattr(self, attr + '_move_')
            trajectories.append(
                move_fn(getattr(start, attr), getattr(end, attr), self.n_t))
        trajectories.append(self.vot)
        trajectories.append(self.viz_t_move)
        return trajectories

    def reset_cam(self, m, a, e, d, f, r):
        # Position the mlab camera, then apply the roll separately.
        m.view(azimuth=a, elevation=e, distance=d, focalpoint=f)
        m.roll(r)

    def take(self, ftv):
        # Play the transition interactively, pausing fta.anim_delay between
        # frames.  The elevation is passed through float() here.
        for frame in zip(*self.transition_arr):
            azimuth, elevation, distance, fpoint, roll, vot, viz_t = frame
            ftv.update(vot, viz_t, force=True)
            self.reset_cam(ftv.mlab, azimuth, float(elevation),
                           distance, fpoint, roll)
            sleep(self.fta.anim_delay)

    def render_take(self, ftv, fname_base, format_, idx_offset, figsize_factor):
        # Render every frame of the transition to an image file and return
        # the list of file names.  Unlike take(), the elevation is passed
        # through unchanged.
        im_files = []
        for idx, frame in enumerate(zip(*self.transition_arr)):
            azimuth, elevation, distance, fpoint, roll, vot, viz_t = frame
            ftv.update(vot, viz_t, force=True)
            # @todo: temporary focal point determination - make it optional
            self.reset_cam(ftv.mlab, azimuth, elevation, distance, fpoint, roll)
            fname = '%s%03d.%s' % (fname_base, idx + idx_offset, format_)
            image_size = (figsize_factor * 800, figsize_factor * 500)
            ftv.mlab.savefig(fname, size=image_size)
            im_files.append(fname)
        return im_files

    view = View(VGroup(HGroup(InstanceUItem('from_station@', resizable=True),
                              VGroup(Item('azimuth_move'),
                                     Item('elevation_move'),
                                     Item('distance_move'),
                                     Item('roll_move'),
                                     Item('fpoint_move'),
                                     Item('n_t'),
                                     Item('duration'),
                                     ),
                              InstanceUItem('to_station@', resizable=True),
                              ),
                       VGroup(HGroup(UItem('vot_start'),
                                     UItem('vot_end'),
                                     springy=True
                                     ),
                              label='object time range'
                              ),
                       ),
                buttons=['OK', 'Cancel'])
コード例 #13
0
class OutputList(HasTraits):
    """This class has methods to emulate a file-like output list of strings.

    The `max_len` attribute specifies the maximum number of log items kept
    in each list.  `max_len` may be set to None to disable truncation.

    The `paused` attribute is a bool; when True, text written to the
    OutputList is saved in a separate buffer, and the display (if there is
    one) does not update.  When `paused` is set back to False, the paused
    buffer replaces the main list and the filtered list is rebuilt.
    """

    # Holds LogItems to display
    unfiltered_list = List(LogItem)
    # Holds LogItems while self.paused is True.
    _paused_buffer = List(LogItem)
    # filtered set of messages
    filtered_list = List(LogItem)
    # state of filter on messages
    log_level_filter = Enum(list(SYSLOG_LEVELS.keys()))
    # The maximum allowed number of items in each list; None disables
    # truncation.
    max_len = Trait(DEFAULT_MAX_LEN, None, Int)

    # When True, the 'write' or 'write_level' methods append to self._paused_buffer
    # When the value changes from True to False, self._paused_buffer is copied
    # back to self.unfiltered_list.
    paused = Bool(False)

    def __init__(self, tfile=False, outdir=''):
        """
        Parameters
        ----------
        tfile : bool
          When true, messages passed to `write` are also written to the file
          LOGFILE inside `outdir`.
        outdir : str
          Directory for the log file (only used when `tfile` is true).
        """
        super(OutputList, self).__init__()
        if tfile:
            self.logfile = sopen(os.path.join(outdir, LOGFILE), 'w')
            self.tfile = True
        else:
            self.tfile = False

    def _append_log(self, log):
        """Route a LogItem to the paused buffer or to the displayed lists."""
        if self.paused:
            self.append_truncate(self._paused_buffer, log)
        else:
            self.append_truncate(self.unfiltered_list, log)
            if log.matches_log_level_filter(self.log_level_filter):
                self.append_truncate(self.filtered_list, log)

    def write(self, s):
        """
        Write to the lists OutputList as STDOUT or STDERR.

        This method exists to allow STDERR and STDOUT to be redirected into
        this display. It should only be called when writing to STDOUT and
        STDERR.  Any log levels from this method will be CONSOLE_LOG_LEVEL.
        Whitespace-only strings are ignored.

        Parameters
        ----------
        s : str
          string to cast as LogItem and write to tables
        """
        if s and not s.isspace():
            log = LogItem(s, CONSOLE_LOG_LEVEL)
            self._append_log(log)
            if self.tfile:
                self.logfile.write(log.print_to_log())

    def write_level(self, s, level):
        """
        Write to the lists in OutputList from device or user space.

        Unlike `write`, this does not write to the log file.

        Parameters
        ----------
        s : str
          string to cast as LogItem and write to tables
        level : int
          Integer log level to use when creating log item.
        """
        if s and not s.isspace():
            self._append_log(LogItem(s, level))

    def append_truncate(self, buffer, s):
        """
        Insert an item at the front of a buffer, keeping at most max_len items.

        The newest item goes to index 0; the oldest items (at the end) are
        dropped once the buffer exceeds max_len.  A max_len of None means
        the buffer is never truncated.

        Parameters
        ----------
        buffer : List
          Buffer to prepend to (modified in place).
        s : LogItem
          Log Item to add
        """
        buffer.insert(0, s)
        limit = self.max_len
        if limit is not None and len(buffer) > limit:
            del buffer[max(limit, 0):]

    def clear(self):
        """
        Clear all Output_list buffers
        """
        self._paused_buffer = []
        self.filtered_list = []
        self.unfiltered_list = []

    def flush(self):
        GUI.process_events()

    def close(self):
        if self.tfile:
            self.logfile.close()

    def _log_level_filter_changed(self):
        """
        Copy items from unfiltered list into filtered list
        """
        self.filtered_list = [item for item in self.unfiltered_list
                              if item.matches_log_level_filter(self.log_level_filter)]

    def _paused_changed(self):
        """
        Swap buffers around when the paused boolean changes state.
        """
        if self.paused:
            # Copy the current list to _paused_buffer.  While the OutputStream
            # is paused, the write methods will add their argument to _paused_buffer.
            self._paused_buffer = self.unfiltered_list
        else:
            # No longer paused, so copy the _paused_buffer to the displayed list, and
            # reset _paused_buffer.
            self.unfiltered_list = self._paused_buffer
            # we have to refilter the filtered list too
            self._log_level_filter_changed()
            self._paused_buffer = []

    def traits_view(self):
        editor = TabularEditor(adapter=LogItemOutputListAdapter(), editable=False,
                               vertical_lines=False, horizontal_lines=False)
        return View(UItem('filtered_list', editor=editor))
コード例 #14
0
ファイル: signals.py プロジェクト: NormonisPing/acoular
class GenericSignalGenerator(SignalGenerator):
    """
    Generate signal from output of :class:`~acoular.tprocess.SamplesGenerator` object.
    """
    #: Data source; :class:`~acoular.tprocess.SamplesGenerator` or derived object.
    source = Trait(SamplesGenerator)

    #: Sampling frequency of output signal, as given by :attr:`source`.
    sample_freq = Delegate('source')

    # Explicitly requested number of samples; 0 means "use source.numsamples".
    _numsamples = CLong(0)

    #: Number of samples to generate. Is set to source.numsamples by default.
    numsamples = Property()

    def _get_numsamples(self):
        if self._numsamples:
            return self._numsamples
        else:
            return self.source.numsamples

    def _set_numsamples(self, numsamples):
        self._numsamples = numsamples

    #: Boolean flag, if 'True' (default), signal track is repeated if requested
    #: :attr:`numsamples` is higher than available sample number
    loop_signal = Bool(True)

    # internal identifier
    # NOTE(review): ``numsamples`` is a Property without ``depends_on``; if
    # assigning it emits no change notification, this cached digest will not
    # refresh when only ``numsamples`` changes -- confirm.
    digest = Property(
        depends_on=['source.digest', 'loop_signal', 'numsamples',
                    'rms', '__class__'],
        )

    @cached_property
    def _get_digest(self):
        return digest(self)

    def signal(self):
        """
        Deliver the signal.

        Only channel 0 of the source is used.  If the source delivers fewer
        than :attr:`numsamples` samples and :attr:`loop_signal` is True, the
        delivered samples are repeated to fill the track; otherwise the
        remainder stays zero.

        Returns
        -------
        array of floats
            The resulting signal as an array of length :attr:`~GenericSignalGenerator.numsamples`.
        """
        block = 1024
        if self.source.numchannels > 1:
            warn(
                "Signal source has more than one channel. Only channel 0 will be used for signal.",
                Warning,
                stacklevel=2)
        nums = self.numsamples
        track = zeros(nums)

        # Number of source samples copied into the track; stays 0 if the
        # source yields no blocks at all.
        stop = 0
        # iterate through source generator to fill signal track
        for i, temp in enumerate(self.source.result(block)):
            start = block * i
            stop = start + temp.shape[0]
            if nums > stop:
                track[start:stop] = temp[:, 0]
            else:  # exit loop preliminarily if wanted signal samples are reached
                track[start:nums] = temp[:nums - start, 0]
                break

        # Repeat the delivered samples if the track is not yet full.  An
        # empty source (stop == 0) has nothing to repeat and would otherwise
        # divide by zero below.
        if self.loop_signal and 0 < stop < nums:

            # fill up empty track with as many full source signals as possible
            nloops = nums // stop
            if nloops > 1:
                track[stop:stop * nloops] = tile(track[:stop], nloops - 1)
            # fill up remaining empty track
            res = nums % stop  # last part of unfinished loop
            if res > 0:
                track[stop * nloops:] = track[:res]

        # The rms value is just an amplification here
        return self.rms * track
コード例 #15
0
ファイル: _settings.py プロジェクト: ma-sadeghi/OpenPNM
 def __setattr__(self, attr, val):
     """Forward attribute assignment to the wrapped settings object.

     Names that are visible traits of ``self._settings`` are assigned
     directly; any other name is added to it as a new trait whose default
     is ``val`` and whose type is ``val``'s class.
     """
     settings = self._settings
     if attr not in settings.visible_traits():
         new_trait = Trait(val, val.__class__)
         settings.add_trait(attr, new_trait)
     else:
         setattr(settings, attr, val)
コード例 #16
0
# (C) Copyright 2005-2021 Enthought, Inc., Austin, TX
# All rights reserved.
#
# This software is provided without warranty under the terms of the BSD
# license included in LICENSE.txt and may be redistributed only under
# the conditions described in the aforementioned license. The license
# is also available online at http://www.enthought.com/licenses/BSD.txt
#
# Thanks for using Enthought open source!

# traitprefixmap.py --- Example of using the TraitPrefixMap handler

# --[Imports]------------------------------------------------------------------
from traits.api import Trait, TraitPrefixMap

# --[Code]---------------------------------------------------------------------
# Mapped trait defaulting to "true": the accepted words map to 1 ("true",
# "yes") or 0 ("false", "no").  TraitPrefixMap also accepts an unambiguous
# prefix of any of these words.
boolean_map = Trait(
    "true",
    TraitPrefixMap(dict(true=1, yes=1, false=0, no=0)),
)
コード例 #17
0
class Foo(HasTraits):
    # Trait ``s`` defaults to the empty string and uses ``validator`` (defined
    # elsewhere) as its handler -- presumably a function that validates each
    # assigned value; confirm against its definition.
    s = Trait("", validator)
コード例 #18
0
from .include import Include
from .item import Item
from .menu import Action
from .table_column import ObjectColumn
from .view import View

#-------------------------------------------------------------------------
#  Trait definitions:
#-------------------------------------------------------------------------

# Operators offered by a generic table-filter rule, each mapped to the name
# of the comparison it performs; '=' is the default.
GenericTableFilterRuleOperation = Trait('=', dict([
    ('=', 'eq'),
    ('<>', 'ne'),
    ('<', 'lt'),
    ('<=', 'le'),
    ('>', 'gt'),
    ('>=', 'ge'),
    ('contains', 'contains'),
    ('starts with', 'starts_with'),
    ('ends with', 'ends_with'),
]))

#-------------------------------------------------------------------------
#  'TableFilter' class:
#-------------------------------------------------------------------------


class TableFilter(HasPrivateTraits):
    """ Filter for items displayed in a table.
    """
    # NOTE(review): only the docstring of this class is present in this
    # excerpt; its traits and methods appear to have been cut off.
コード例 #19
0
ファイル: helper_functions.py プロジェクト: victorliun/mayavi
class BarChart(Pipeline):
    """
    Plots vertical glyphs (like bars) scaled vertically, to make
    histogram-like plots.

    This function accepts a wide variety of inputs, with positions given
    in 2-D or in 3-D.

    **Function signatures**::

        barchart(s, ...)
        barchart(x, y, s, ...)
        barchart(x, y, f, ...)
        barchart(x, y, z, s, ...)
        barchart(x, y, z, f, ...)

    If only one positional argument is passed, it can be a 1-D, 2-D, or 3-D
    array giving the length of the vectors. The positions of the data
    points are deduced from the indices of the array, and a
    uniformly-spaced data set is created.

    If 3 positional arguments (x, y, s) are passed the last one must be
    an array s, or a callable, f, that returns an array. x and y give the
    2D coordinates of positions corresponding to the s values.

    If 4 positional arguments (x, y, z, s) are passed, the 3 first are
    arrays giving the 3D coordinates of the data points, and the last one
    is an array s, or a callable, f, that returns an array giving the
    data value.
    """

    _source_function = Callable(vertical_vectors_source)

    _pipeline = [
        VectorsFactory,
    ]

    mode = Trait('cube',
                 bar_mode_dict,
                 desc='The glyph used to represent the bars.')

    lateral_scale = CFloat(0.9,
                           desc='The lateral scale of the glyph, '
                           'in units of the distance between nearest points')

    auto_scale = true(desc='whether to compute automatically the '
                      'lateral scaling of the glyphs. This might be '
                      'computationally expensive.')

    def __call_internal__(self, *args, **kwargs):
        """ Override the call to be able to scale automatically the axis.
        """
        g = Pipeline.__call_internal__(self, *args, **kwargs)
        gs = g.glyph.glyph_source
        # Use a cube source for glyphs unless the caller chose a mode.
        if 'mode' not in kwargs:
            gs.glyph_source = gs.glyph_dict['cube_source']
        # Position the glyph tail on the point.
        gs.glyph_position = 'tail'
        gs.glyph_source.center = (0.0, 0.0, 0.5)
        g.glyph.glyph.orient = False
        if 'color' not in kwargs:
            g.glyph.color_mode = 'color_by_scalar'
        if 'scale_mode' not in kwargs:
            g.glyph.scale_mode = 'scale_by_vector_components'
        g.glyph.glyph.clamping = False
        # The auto-scaling code. It involves finding the minimum
        # distance between points, which can be very expensive. We
        # shortcut this calculation for structured data (a single
        # positional argument) and whenever auto-scaling is disabled.
        if len(args) == 1 or not self.auto_scale:
            min_axis_distance = 1
        else:
            x, y, z = g.mlab_source.x, g.mlab_source.y, g.mlab_source.z
            min_axis_distance = tools._min_axis_distance(x, y, z)
        scale_factor = g.glyph.glyph.scale_factor * min_axis_distance
        lateral_scale = kwargs.pop('lateral_scale', self.lateral_scale)
        try:
            g.glyph.glyph_source.glyph_source.y_length = \
                    lateral_scale / scale_factor
            g.glyph.glyph_source.glyph_source.x_length = \
                    lateral_scale / scale_factor
        except TraitError:
            # Not all types of glyphs have controllable y_length and x_length.
            pass

        return g
コード例 #20
0
class RTraceDomainList(HasTraits):
    """ Response tracer that patches the VTK data of several subfields.

    The point coordinates and cell connectivity of every subfield that is
    not marked ``skip_domain`` are concatenated into one
    ``tvtk.UnstructuredGrid``, which is handed to the ``mvp_mgrid_geo``
    visualization pipeline.
    """

    label = Str('RTraceDomainField')
    sd = WeakRef(ISDomain)
    # Where the grid points are taken from; the value is forwarded to every
    # subfield. ('int_pnts' presumably means integration points.)
    position = Enum('nodes', 'int_pnts')
    subfields = List

    def redraw(self):
        '''Rebuild the mesh pipeline from the node-based VTK structure.
        '''
        self.mvp_mgrid_geo.rebuild_pipeline(self.vtk_node_structure)

    vtk_node_structure = Property(Instance(tvtk.UnstructuredGrid))

    # NOTE(review): this property is deliberately not @cached_property (the
    # decorator was commented out originally); presumably because the getter
    # switches ``position``, on which ``vtk_structure`` depends.
    def _get_vtk_node_structure(self):
        self.position = 'nodes'
        return self.vtk_structure

    vtk_ip_structure = Property(Instance(tvtk.UnstructuredGrid))

    # NOTE(review): not cached for the same reason as vtk_node_structure.
    def _get_vtk_ip_structure(self):
        self.position = 'int_pnts'
        return self.vtk_structure

    vtk_structure = Property(Instance(tvtk.UnstructuredGrid))

    def _get_vtk_structure(self):
        """Assemble an unstructured grid for the current ``position``."""
        ug = tvtk.UnstructuredGrid()
        cell_array, cell_offsets, cell_types = self.vtk_cell_data
        n_cells = cell_types.shape[0]
        ug.points = self.vtk_X
        vtk_cell_array = tvtk.CellArray()
        vtk_cell_array.set_cells(n_cells, cell_array)
        ug.set_cells(cell_types, cell_offsets, vtk_cell_array)
        return ug

    vtk_X = Property

    def _get_vtk_X(self):
        """Stack the point coordinates of all active subfields.

        Subfields with ``skip_domain`` set, or with no points (all elements
        deactivated), are left out. Returns an empty ``(0, 3)`` float array
        when no subfield contributes points.
        """
        point_arr_list = []
        for sf in self.subfields:
            if sf.skip_domain:
                continue
            sf.position = self.position
            sf_vtk_X = sf.vtk_X
            if sf_vtk_X.shape[0] == 0:  # all elem are deactivated
                continue
            point_arr_list.append(sf_vtk_X)
        if point_arr_list:
            return vstack(point_arr_list)
        # 'float64' rather than the 'float_' alias, which NumPy 2.0 removed;
        # both denote the same dtype on NumPy 1.x.
        return zeros((0, 3), dtype='float64')

    # point offset to use when more fields are patched together within
    # RTDomainList

    point_offset = Int(0)

    # cell offset to use when more fields are patched together within
    # RTDomainList

    cell_offset = Int(0)

    vtk_cell_data = Property

    def _get_vtk_cell_data(self):
        """Concatenate the cell data of all non-skipped subfields.

        Each subfield is given the running point and cell offsets (starting
        from ``self.point_offset`` / ``self.cell_offset``) before its
        ``vtk_cell_data`` is read. Returns the tuple
        ``(cell_array, cell_offsets, cell_types)``; all three are empty
        integer arrays when no subfield contributes.
        """
        cell_array_list = []
        cell_offset_list = []
        cell_types_list = []
        point_offset = self.point_offset
        cell_offset = self.cell_offset
        for sf in self.subfields:
            if sf.skip_domain:
                continue
            sf.position = self.position
            sf.point_offset = point_offset
            sf.cell_offset = cell_offset
            cell_array, cell_offsets, cell_types = sf.vtk_cell_data
            cell_array_list.append(cell_array)
            cell_offset_list.append(cell_offsets)
            cell_types_list.append(cell_types)
            point_offset += sf.n_points
            cell_offset += cell_array.shape[0]
        if cell_array_list:
            cell_array = hstack(cell_array_list)
            cell_offsets = hstack(cell_offset_list)
            cell_types = hstack(cell_types_list)
        else:
            cell_array = array([], dtype='int_')
            cell_offsets = array([], dtype='int_')
            cell_types = array([], dtype='int_')
        return (cell_array, cell_offsets, cell_types)

    #-------------------------------------------------------------------------
    # Visualization pipelines
    #-------------------------------------------------------------------------

    mvp_mgrid_geo = Trait(MVUnstructuredGrid)

    def _mvp_mgrid_geo_default(self):
        return MVUnstructuredGrid(name='Response tracer mesh',
                                  warp=False,
                                  warp_var='')

    view = View(resizable=True)
コード例 #21
0
ファイル: widget.py プロジェクト: pingleewu/ShareRoot
class Part(HasTraits):
    """ An object with a single ``cost`` trait. """

    # The cost of the part; the trait's default value is 0.0.
    cost = Trait(0.0)
コード例 #22
0
class ImagePlot(Base2DPlot):
    """ A plot based on an image.

    The part of ``value.data`` that is visible in the plot is cut out,
    wrapped in a kiva ``GraphicsContextArray`` and cached together with the
    screen rectangle it is drawn into. The event handlers below invalidate
    that cache when the index data, the index mapper or the value data
    change.
    """
    #------------------------------------------------------------------------
    # Data-related traits
    #------------------------------------------------------------------------

    #: Overall alpha value of the image. Ranges from 0.0 for transparent to 1.0
    #: for full intensity.
    alpha = Trait(1.0, Range(0.0, 1.0))

    #: The interpolation method to use when rendering an image onto the GC.
    interpolation = Enum("nearest", "bilinear", "bicubic")

    #: Bool indicating whether x-axis is flipped.
    x_axis_is_flipped = Property(depends_on=['orientation', 'origin'])

    #: Bool indicating whether y-axis is flipped.
    y_axis_is_flipped = Property(depends_on=['orientation', 'origin'])

    #------------------------------------------------------------------------
    # Private traits
    #------------------------------------------------------------------------

    # Are the cache traits valid? If False, new ones need to be computed.
    _image_cache_valid = Bool(False)

    # Cached image of the bmp data (not the bmp data in self.data.value).
    _cached_image = Instance(GraphicsContextArray)

    # Tuple-defined rectangle (x, y, dx, dy) in screen space in which the
    # **_cached_image** is to be drawn.
    _cached_dest_rect = Either(Tuple, List)

    # Bool indicating whether the origin is top-left or bottom-right.
    # The name "principal diagonal" is borrowed from linear algebra.
    _origin_on_principal_diagonal = Property(depends_on='origin')

    #------------------------------------------------------------------------
    # Properties
    #------------------------------------------------------------------------

    @cached_property
    def _get_x_axis_is_flipped(self):
        return ((self.orientation == 'h' and 'right' in self.origin)
                or (self.orientation == 'v' and 'top' in self.origin))

    @cached_property
    def _get_y_axis_is_flipped(self):
        return ((self.orientation == 'h' and 'top' in self.origin)
                or (self.orientation == 'v' and 'right' in self.origin))

    #------------------------------------------------------------------------
    # Event handlers
    #------------------------------------------------------------------------

    def _index_data_changed_fired(self):
        self._image_cache_valid = False
        self.request_redraw()

    def _index_mapper_changed_fired(self):
        self._image_cache_valid = False
        self.request_redraw()

    def _value_data_changed_fired(self):
        self._image_cache_valid = False
        self.request_redraw()

    #------------------------------------------------------------------------
    # Base2DPlot interface
    #------------------------------------------------------------------------

    def _render(self, gc):
        """ Draw the plot to screen.

        Implements the Base2DPlot interface. Recomputes the cached image if
        it is invalid, and draws nothing when the cached destination
        rectangle has a non-positive width or height.
        """
        if not self._image_cache_valid:
            self._compute_cached_image()

        scale_x = -1 if self.x_axis_is_flipped else 1
        scale_y = 1 if self.y_axis_is_flipped else -1

        x, y, w, h = self._cached_dest_rect
        if w <= 0 or h <= 0:
            return

        x_center = x + w / 2
        y_center = y + h / 2
        with gc:
            gc.clip_to_rect(self.x, self.y, self.width, self.height)
            gc.set_alpha(self.alpha)

            # Translate origin to the center of the graphics context.
            if self.orientation == "h":
                gc.translate_ctm(x_center, y_center)
            else:
                gc.translate_ctm(y_center, x_center)

            # Flip axes to move origin to the correct position.
            gc.scale_ctm(scale_x, scale_y)

            if self.orientation == "v":
                self._transpose_about_origin(gc)

            # Translate the origin back to its original position.
            gc.translate_ctm(-x_center, -y_center)

            with self._temporary_interp_setting(gc):
                gc.draw_image(self._cached_image, self._cached_dest_rect)

    def map_index(self,
                  screen_pt,
                  threshold=0.0,
                  outside_returns_none=True,
                  index_only=False):
        """ Maps a screen space point to an index into the plot's index
        array(s).

        Implements the AbstractPlotRenderer interface. Uses 0.0 for
        *threshold*, regardless of the passed value.
        """
        # For image plots, treat hittesting threshold as 0.0, because it's
        # the only thing that really makes sense.
        return Base2DPlot.map_index(self, screen_pt, 0.0, outside_returns_none,
                                    index_only)

    #------------------------------------------------------------------------
    # Private methods
    #------------------------------------------------------------------------

    @cached_property
    def _get__origin_on_principal_diagonal(self):
        bottom_right = 'bottom' in self.origin and 'right' in self.origin
        top_left = 'top' in self.origin and 'left' in self.origin
        return bottom_right or top_left

    def _transpose_about_origin(self, gc):
        """ Transpose the GC's axes (used for vertical orientation). """
        if self._origin_on_principal_diagonal:
            gc.scale_ctm(-1, 1)
        else:
            gc.scale_ctm(1, -1)
        gc.rotate_ctm(pi / 2)

    @contextmanager
    def _temporary_interp_setting(self, gc):
        """ Apply ``self.interpolation`` for the duration of a draw call.

        Uses whichever interpolation API the GC exposes; does nothing if it
        exposes neither.
        """
        if hasattr(gc, "set_interpolation_quality"):
            # Quartz uses interpolation setting on the destination GC.
            interp_quality = QUARTZ_INTERP_QUALITY[self.interpolation]
            gc.set_interpolation_quality(interp_quality)
            yield
        elif hasattr(gc, "set_image_interpolation"):
            # Agg backend uses the interpolation setting of the *source*
            # image to determine the type of interpolation to use when
            # drawing. Temporarily change image's interpolation value.
            old_interp = self._cached_image.get_image_interpolation()
            set_interp = self._cached_image.set_image_interpolation
            try:
                set_interp(self.interpolation)
                yield
            finally:
                set_interp(old_interp)
        else:
            yield

    def _calc_virtual_screen_bbox(self):
        """ Return the rectangle describing the image in screen space
        assuming that the entire image could fit on screen.

        Zoomed-in images will have "virtual" sizes larger than the image.
        Note that vertical orientations flip x- and y-axes such that x is
        vertical and y is horizontal.
        """
        # Upper-right values are always larger than lower-left values,
        # regardless of origin or orientation...
        (lower_left, upper_right) = self.index.get_bounds()
        # ... but if the origin is not 'bottom left', the data-to-screen
        # mapping will flip min and max values.
        x_min, y_min = self.map_screen([lower_left])[0]
        x_max, y_max = self.map_screen([upper_right])[0]
        if x_min > x_max:
            x_min, x_max = x_max, x_min
        if y_min > y_max:
            y_min, y_max = y_max, y_min

        virtual_x_size = x_max - x_min
        virtual_y_size = y_max - y_min

        # Convert to the coordinates of the graphics context, which expects
        # origin to be at the center of a pixel.
        x_min += 0.5
        y_min += 0.5
        return [x_min, y_min, virtual_x_size, virtual_y_size]

    def _compute_cached_image(self, data=None, mapper=None):
        """ Computes the correct screen coordinates and renders an image into
        `self._cached_image`.

        Parameters
        ----------
        data : array
            Image data. If None, image is derived from the `value` attribute.
        mapper : function
            Allows subclasses to transform the displayed values for the visible
            region. This may be used to adapt grayscale images to RGB(A)
            images.

        Raises
        ------
        RuntimeError
            If the (mapped) visible data is not a 3-D color image.
        """
        if data is None:
            data = self.value.data

        virtual_rect = self._calc_virtual_screen_bbox()
        index_bounds, screen_rect = self._calc_zoom_coords(virtual_rect)
        if index_bounds is None:
            # The virtual image or the plot has zero size, so nothing can be
            # drawn. Store an empty destination rectangle so that _render
            # returns early, and leave the cache invalid so that it is
            # recomputed on the next draw.
            self._cached_dest_rect = [0.0, 0.0, 0.0, 0.0]
            self._image_cache_valid = False
            return
        col_min, col_max, row_min, row_max = index_bounds

        view_rect = self.position + self.bounds
        sub_array_size = (col_max - col_min, row_max - row_min)
        screen_rect = trim_screen_rect(screen_rect, view_rect, sub_array_size)

        data = data[row_min:row_max, col_min:col_max]

        if mapper is not None:
            data = mapper(data)

        if len(data.shape) != 3:
            raise RuntimeError("`ImagePlot` requires color images.")

        # Update cached image and rectangle.
        self._cached_image = self._kiva_array_from_numpy_array(data)
        self._cached_dest_rect = screen_rect
        self._image_cache_valid = True

    def _kiva_array_from_numpy_array(self, data):
        """ Wrap a (rows, cols, depth) array in a GraphicsContextArray.

        Raises RuntimeError if ``depth`` is not a key of KIVA_DEPTH_MAP.
        """
        if data.shape[2] not in KIVA_DEPTH_MAP:
            msg = "Unknown colormap depth value: {}"
            raise RuntimeError(msg.format(data.shape[2]))
        kiva_depth = KIVA_DEPTH_MAP[data.shape[2]]

        # Data presented to the GraphicsContextArray needs to be contiguous.
        data = np.ascontiguousarray(data)
        return GraphicsContextArray(data, pix_format=kiva_depth)

    def _calc_zoom_coords(self, image_rect):
        """ Calculates the coordinates of a zoomed sub-image.

        Because of floating point limitations, it is not advisable to request a
        extreme level of zoom, e.g., idx or idy > 10^10.

        Parameters
        ----------
        image_rect : 4-tuple
            (x, y, width, height) rectangle describing the pixels bounds of the
            full, **rendered** image. This will be larger than the canvas when
            zoomed in since the full image may not fit on the canvas.

        Returns
        -------
        index_bounds : 4-tuple
            The column and row indices (col_min, col_max, row_min, row_max) of
            the sub-image to be extracted and drawn into `screen_rect`.
        screen_rect : 4-tuple
            (x, y, width, height) rectangle describing the pixels bounds where
            the image will be rendered in the plot.

        Both are None when the image or the plot has zero width or height.
        """
        ix, iy, image_width, image_height = image_rect
        if 0 in (image_width, image_height) or 0 in self.bounds:
            return (None, None)

        array_bounds = self._array_bounds_from_screen_rect(image_rect)
        col_min, col_max, row_min, row_max = array_bounds
        # Convert array indices back into screen coordinates after its been
        # clipped to fit within the bounds.
        array_width = self.value.get_width()
        array_height = self.value.get_height()
        x_min = float(col_min) / array_width * image_width + ix
        x_max = float(col_max) / array_width * image_width + ix
        y_min = float(row_min) / array_height * image_height + iy
        y_max = float(row_max) / array_height * image_height + iy

        # Flip indexes **after** calculating screen coordinates.
        # The screen coordinates will get flipped in the renderer.
        if self.y_axis_is_flipped:
            row_min = array_height - row_min
            row_max = array_height - row_max
            row_min, row_max = row_max, row_min
        if self.x_axis_is_flipped:
            col_min = array_width - col_min
            col_max = array_width - col_max
            col_min, col_max = col_max, col_min

        index_bounds = list(map(int, [col_min, col_max, row_min, row_max]))
        screen_rect = [x_min, y_min, x_max - x_min, y_max - y_min]
        return index_bounds, screen_rect

    def _array_bounds_from_screen_rect(self, image_rect):
        """ Transform virtual-image rectangle into array indices.

        The virtual-image rectangle is in screen coordinates and can be outside
        the plot bounds. This method converts the rectangle into array indices
        and clips to the plot bounds.
        """
        # Plot dimensions are independent of orientation and origin, but data
        # dimensions vary with orientation. Flip plot dimensions to match data
        # since outputs will be in data space.
        if self.orientation == "h":
            x_min, y_min = self.position
            plot_width, plot_height = self.bounds
        else:
            y_min, x_min = self.position
            plot_height, plot_width = self.bounds

        ix, iy, image_width, image_height = image_rect
        # Screen coordinates of virtual-image that fit into plot window.
        x_min -= ix
        y_min -= iy
        x_max = x_min + plot_width
        y_max = y_min + plot_height

        array_width = self.value.get_width()
        array_height = self.value.get_height()
        # Convert screen coordinates to array indexes
        col_min = floor(float(x_min) / image_width * array_width)
        col_max = ceil(float(x_max) / image_width * array_width)
        row_min = floor(float(y_min) / image_height * array_height)
        row_max = ceil(float(y_max) / image_height * array_height)

        # Clip index bounds to the array bounds.
        col_min = max(col_min, 0)
        col_max = min(col_max, array_width)
        row_min = max(row_min, 0)
        row_max = min(row_max, array_height)

        return col_min, col_max, row_min, row_max
コード例 #23
0
ファイル: table_filter.py プロジェクト: kitchoi/traitsui
from .item import Item
from .menu import Action
from .table_column import ObjectColumn
from .view import View

# -------------------------------------------------------------------------
#  Trait definitions:
# -------------------------------------------------------------------------

# Mapped trait for the comparison operator of a generic filter rule. The
# user-facing value is the operator symbol (default "="); the mapping gives
# the associated operation name, available through the shadow ``<name>_``
# trait. Entries are listed in the same order as the displayed choices.
GenericTableFilterRuleOperation = Trait(
    "=",
    dict(
        (
            ("=", "eq"),
            ("<>", "ne"),
            ("<", "lt"),
            ("<=", "le"),
            (">", "gt"),
            (">=", "ge"),
            ("contains", "contains"),
            ("starts with", "starts_with"),
            ("ends with", "ends_with"),
        )
    ),
)


class TableFilter(HasPrivateTraits):
    """ Filter for items displayed in a table.
    """

    # -------------------------------------------------------------------------
    #  Trait definitions:
    # -------------------------------------------------------------------------
    # NOTE(review): no trait definitions follow this header in this excerpt.
コード例 #24
0
class TextBoxOverlay(AbstractOverlay):
    """ Draws a box with a text in it
    """
    #### Configuration traits ##################################################
    # The text to display in the box.
    text = Str

    # The font to use for the text.
    font = KivaFont("swiss 12")

    # The background color for the box (overrides AbstractOverlay).
    bgcolor = ColorTrait("transparent")

    # The alpha value to apply to **bgcolor**; None leaves the color's own
    # alpha untouched.
    alpha = Trait(1.0, None, Float)

    # The color of the outside box.
    border_color = ColorTrait("dodgerblue")

    # The color of the text in the tooltip
    text_color = black_color_trait

    # The thickness of box border.
    border_size = Int(1)

    # Number of pixels of padding around the text within the box.
    padding = Int(5)

    # Alignment of the text in the box:
    #
    # * "ur": upper right
    # * "ul": upper left
    # * "ll": lower left
    # * "lr": lower right
    align = Enum("ur", "ul", "ll", "lr")
    # This allows subclasses to specify an alternate position for the root
    # of the text box. Must be a sequence of length 2.
    alternate_position = Any

    #### Public 'AbstractOverlay' interface ####################################
    def overlay(self, component, gc, view_bounds=None, mode="normal"):
        """ Draws the box overlaid on another component.

        Overrides AbstractOverlay. The box is anchored at the corner of
        *component* selected by **align** (or relative to
        **alternate_position** when that is set), then shifted where possible
        so that it lies inside the component's bounds.
        """
        if not self.visible:
            return

        # draw the label on a transparent box. This allows us to draw
        # different shapes and put the text inside it without the label
        # filling a rectangle on top of it
        label = Label(text=self.text,
                      font=self.font,
                      bgcolor="transparent",
                      color=self.text_color,
                      margin=5)
        width, height = label.get_width_height(gc)
        valign, halign = self.align
        if self.alternate_position:
            x, y = self.alternate_position
            if valign == "u":
                y += self.padding
            else:
                y -= self.padding + height
            if halign == "r":
                x += self.padding
            else:
                x -= self.padding + width
        else:
            if valign == "u":
                y = component.y2 - self.padding - height
            else:
                y = component.y + self.padding
            if halign == "r":
                x = component.x2 - self.padding - width
            else:
                x = component.x + self.padding

        # Attempt to get the box entirely within the component. The box
        # position is expressed in the same coordinates as component.x/x2
        # and component.y/y2 (see the alignment code above), so clamp
        # against those edges rather than against width/height, which are
        # sizes and only coincide with the edges for a component at origin.
        x_min, x_max = component.x, component.x2
        y_min, y_max = component.y, component.y2
        if x + width > x_max:
            x = max(x_min, x_max - width)
        elif x < x_min:
            x = x_min
        if y + height > y_max:
            y = max(y_min, y_max - height)
        elif y < y_min:
            y = y_min

        # apply the alpha channel; 0.0 is a valid alpha, only None means
        # "keep the background color's own alpha"
        color = self.bgcolor_
        if self.bgcolor != "transparent" and self.alpha is not None:
            color = list(self.bgcolor_)
            if len(color) == 4:
                color[3] = self.alpha
            else:
                color += [self.alpha]

        with gc:
            gc.translate_ctm(x, y)

            gc.set_line_width(self.border_size)
            gc.set_stroke_color(self.border_color_)
            gc.set_fill_color(color)

            # draw a rounded rectangle in the translated coordinate system
            x = y = 0
            end_radius = 8.0
            gc.begin_path()
            gc.move_to(x + end_radius, y)
            gc.arc_to(x + width, y, x + width, y + end_radius, end_radius)
            gc.arc_to(x + width, y + height, x + width - end_radius,
                      y + height, end_radius)
            gc.arc_to(x, y + height, x, y, end_radius)
            gc.arc_to(x, y, x + width + end_radius, y, end_radius)
            gc.draw_path()

            label.draw(gc)
コード例 #25
0
ファイル: spectra.py プロジェクト: ZhangAustin/acoular
class PowerSpectra(HasPrivateTraits):
    """Provides the cross spectral matrix of multichannel time data.

    This class includes the efficient calculation of the full cross spectral
    matrix using the Welch method with windows and overlap. It also contains
    data and additional properties of this matrix.

    The result is computed only when needed, that is when the :attr:`csm` attribute
    is actually read. Any change in the input data or parameters leads to a
    new calculation, again triggered when csm is read. The result may be
    cached on disk in HDF5 files and need not to be recomputed during
    subsequent program runs with identical input data and parameters. The
    input data is taken to be identical if the source has identical parameters
    and the same file name in case of that the data is read from a file.
    """

    #: The :class:`~acoular.sources.SamplesGenerator` object that provides the data.
    time_data = Trait(SamplesGenerator, desc="time data object")

    #: Number of channels (delegated to :attr:`time_data`).
    numchannels = Delegate('time_data')

    #: The :class:`~acoular.calib.Calib` object that provides the calibration data,
    #: defaults to no calibration, i.e. the raw time data is used.
    #:
    #: **deprecated**:      use :attr:`~acoular.sources.TimeSamples.calib` property of
    #: :class:`~acoular.sources.TimeSamples` objects
    calib = Instance(Calib)

    #: FFT block size, one of: 128, 256, 512, 1024, 2048 ... 16384,
    #: defaults to 1024.
    block_size = Trait(1024,
                       128,
                       256,
                       512,
                       1024,
                       2048,
                       4096,
                       8192,
                       16384,
                       desc="number of samples per FFT block")

    #: Index of lowest frequency line to compute, integer, defaults to 1,
    #: is used only by objects that fetch the csm, PowerSpectra computes every
    #: frequency line.
    ind_low = Range(1, desc="index of lowest frequency line")

    #: Index of highest frequency line to compute, integer,
    #: defaults to -1 (last possible line for default block_size).
    ind_high = Int(-1, desc="index of highest frequency line")

    #: Window function for FFT, one of:
    #:   * 'Rectangular' (default)
    #:   * 'Hanning'
    #:   * 'Hamming'
    #:   * 'Bartlett'
    #:   * 'Blackman'
    window = Trait('Rectangular', {
        'Rectangular': ones,
        'Hanning': hanning,
        'Hamming': hamming,
        'Bartlett': bartlett,
        'Blackman': blackman
    },
                   desc="type of window for FFT")

    #: Overlap factor for averaging: 'None'(default), '50%', '75%', '87.5%'.
    overlap = Trait('None', {
        'None': 1,
        '50%': 2,
        '75%': 4,
        '87.5%': 8
    },
                    desc="overlap of FFT blocks")

    #: Flag, if true (default), the result is cached in h5 files and need not
    #: to be recomputed during subsequent program runs.
    cached = Bool(True, desc="cached flag")

    #: Number of FFT blocks to average, readonly
    #: (set from block_size and overlap).
    num_blocks = Property(desc="overall number of FFT blocks")

    #: 2-element array with the lowest and highest frequency, readonly.
    freq_range = Property(desc="frequency range")

    #: Array with a sequence of indices for all frequencies
    #: between :attr:`ind_low` and :attr:`ind_high` within the result, readonly.
    indices = Property(desc="index range")

    #: Name of the cache file without extension, readonly.
    basename = Property(depends_on='time_data.digest',
                        desc="basename for cache file")

    #: The cross spectral matrix,
    #: (number of frequencies, numchannels, numchannels) array of complex;
    #: readonly.
    csm = Property(desc="cross spectral matrix")

    # internal identifier
    digest = Property(depends_on=[
        'time_data.digest', 'calib.digest', 'block_size', 'window', 'overlap'
    ], )

    # hdf5 cache file
    h5f = Instance(tables.File, transient=True)

    traits_view = View([
        'time_data@{}',
        'calib@{}',
        [
            'block_size', 'window', 'overlap',
            [
                'ind_low{Low Index}', 'ind_high{High Index}',
                '-[Frequency range indices]'
            ],
            [
                'num_blocks~{Number of blocks}',
                'freq_range~{Frequency range}', '-'
            ], '[FFT-parameters]'
        ],
    ],
                       buttons=OKCancelButtons)

    @property_depends_on('time_data.numsamples, block_size, overlap')
    def _get_num_blocks(self):
        # Number of blocks actually averaged in _get_csm: blocks start every
        # block_size/overlap_ samples and must fit completely into the data,
        # i.e. floor(overlap_*numsamples/block_size) - overlap_ + 1.
        # Floor division keeps this an integer on Python 3 as well ('/'
        # would yield a float and skew the CSM normalization).
        return (self.overlap_ * self.time_data.numsamples // self.block_size
                - self.overlap_ + 1)

    @property_depends_on('time_data.sample_freq, block_size, ind_low, ind_high'
                         )
    def _get_freq_range(self):
        try:
            return self.fftfreq()[[self.ind_low, self.ind_high]]
        except IndexError:
            return array([0., 0])

    @property_depends_on('block_size, ind_low, ind_high')
    def _get_indices(self):
        try:
            return arange(self.block_size / 2 + 1,
                          dtype=int)[self.ind_low:self.ind_high]
        except IndexError:
            return range(0)

    @cached_property
    def _get_digest(self):
        return digest(self)

    @cached_property
    def _get_basename(self):
        if 'basename' in self.time_data.all_trait_names():
            return self.time_data.basename
        else:
            return self.time_data.__class__.__name__ + self.time_data.digest

    @property_depends_on('digest')
    def _get_csm(self):
        """
        Main work is done here:
        Cross spectral matrix is either loaded from cache file or
        calculated and then additionally stored into cache.

        Raises ValueError if both this object and the original data source
        carry different calibrations, or if the calibration does not match
        the number of channels.
        """
        # test for dual calibration
        obj = self.time_data  # start with time_data obj
        while obj:
            if 'calib' in obj.all_trait_names():  # at original source?
                if obj.calib and self.calib:
                    if obj.calib.digest == self.calib.digest:
                        self.calib = None  # ignore it silently
                    else:
                        raise ValueError("Non-identical dual calibration for "
                                         "both TimeSamples and PowerSpectra object")
                obj = None
            else:
                try:
                    obj = obj.source  # traverse down until original data source
                except AttributeError:
                    obj = None
        name = 'csm_' + self.digest
        H5cache.get_cache(self, self.basename)
        if not self.cached or name not in self.h5f.root:
            t = self.time_data
            wind = self.window_(self.block_size)
            weight = dot(wind, wind)
            wind = wind[newaxis, :].swapaxes(0, 1)
            numfreq = int(self.block_size / 2 + 1)
            csm_shape = (numfreq, t.numchannels, t.numchannels)
            csmUpper = zeros(csm_shape, 'D')
            # for backward compatibility
            if self.calib and self.calib.num_mics > 0:
                if self.calib.num_mics == t.numchannels:
                    wind = wind * self.calib.data[newaxis, :]
                else:
                    raise ValueError(
                            "Calibration data not compatible: %i, %i" % \
                            (self.calib.num_mics, t.numchannels))
            bs = self.block_size
            # Two-block ring buffer: the previous chunk sits in the first
            # half, the current chunk in the second half, so overlapping
            # blocks can straddle chunk boundaries.
            temp = empty((2 * bs, t.numchannels))
            pos = bs
            posinc = bs / self.overlap_
            for data in t.result(bs):
                ns = data.shape[0]
                temp[bs:bs + ns] = data
                while pos + bs <= bs + ns:
                    ft = fft.rfft(temp[int(pos):int(pos + bs)] * wind, None, 0)
                    calcCSM(
                        csmUpper, ft
                    )  # only upper triangular part of matrix is calculated (for speed reasons)
                    pos += posinc
                temp[0:bs] = temp[bs:]
                pos -= bs

            # create the full csm matrix via transposing and complex conj.
            csmLower = csmUpper.conj().transpose(0, 2, 1)
            # zero the diagonal of the lower part so that the diagonal of
            # the sum comes from csmUpper only
            for cntFreq in range(csmLower.shape[0]):
                fill_diagonal(csmLower[cntFreq, :, :], 0)
            csm = csmLower + csmUpper

            # onesided spectrum: multiplication by 2.0=sqrt(2)^2
            csm = csm * (2.0 / self.block_size / weight / self.num_blocks)

            if self.cached:
                atom = tables.ComplexAtom(8)
                filters = tables.Filters(complevel=5, complib='blosc')
                ac = self.h5f.create_carray(self.h5f.root,
                                            name,
                                            atom,
                                            csm_shape,
                                            filters=filters)
                ac[:] = csm
                return ac
            else:
                return csm
        else:
            return self.h5f.get_node('/', name)

    def fftfreq(self):
        """
        Return the Discrete Fourier Transform sample frequencies.

        Returns
        -------
        f : ndarray
            Array of length *block_size/2+1* containing the sample frequencies.
        """
        return abs(fft.fftfreq(self.block_size, 1./self.time_data.sample_freq)\
                    [:int(self.block_size/2+1)])
コード例 #26
0
    # Should labels be added to items in a group?
    show_labels = Bool(True)

    # The default theme to use for a contained item:
    item_theme = ATheme

    # The default theme to use for a contained item's label:
    label_theme = ATheme


#-------------------------------------------------------------------------------
#  Trait definitions:
#-------------------------------------------------------------------------------

# The container trait used by ViewSubElements: its default value is a
# DefaultViewElement() instance, and it accepts ViewElement instances.
Container = Trait(DefaultViewElement(), ViewElement)

#-------------------------------------------------------------------------------
#  'ViewSubElement' class (abstract):
#-------------------------------------------------------------------------------


class ViewSubElement(ViewElement):
    """ Abstract class representing elements that can be contained in a view.

    This body adds nothing beyond what it inherits from ViewElement.
    """

    #---------------------------------------------------------------------------
    #  Trait definitions:
    #---------------------------------------------------------------------------

    # The object this ViewSubElement is contained in; must be a ViewElement.
    # NOTE(review): no trait declaration follows this comment in this body;
    # the definition it describes appears to be missing -- confirm.
コード例 #27
0
ファイル: compound.py プロジェクト: skailasa/traits
class Die(HasTraits):
    """ A die whose face value is declared as a compound trait.

    The 'value' trait defaults to 1 and combines Range(1, 6) with the
    strings "one" through "six", so either form describes a face.
    """

    # Define a compound trait definition: the default value (1) comes first,
    # followed by the alternatives a value may match -- an integer in
    # Range(1, 6) or one of the listed names.
    value = Trait(1, Range(1, 6), "one", "two", "three", "four", "five", "six")
コード例 #28
0
class GridPlotContainer(BasePlotContainer):
    """ A GridPlotContainer consists of rows and columns in a tabular format.

    Each cell's width is the same as all other cells in its column, and each
    cell's height is the same as all other cells in its row.

    Although grid layout requires more layout information than a simple
    ordered list, this class keeps components as a simple list and exposes a
    **shape** trait.
    """

    draw_order = Instance(list, args=(DEFAULT_DRAWING_ORDER, ))

    # The amount of space to put on either side of each component, expressed
    # as a tuple (h_spacing, v_spacing).
    spacing = Either(Tuple, List, Array)

    # The vertical alignment of objects that don't span the full height.
    valign = Enum("bottom", "top", "center")

    # The horizontal alignment of objects that don't span the full width.
    halign = Enum("left", "right", "center")

    # The shape of this container, i.e, (rows, columns).  The items in
    # **components** are shuffled appropriately to match this
    # specification.  If there are fewer components than cells, the remaining
    # cells are filled with None.  If there are more components than cells,
    # the column count is kept and the remainder wrap onto new rows.
    shape = Trait((0, 0), Either(Tuple, List, Array))

    # This property exposes the underlying grid structure of the container,
    # and is the preferred way of setting and reading its contents.
    # When read, this property returns a Numpy array with dtype=object; values
    # for setting it can be nested tuples, lists, or 2-D arrays.
    # The array is in row-major order, so that component_grid[0] is the first
    # row, and component_grid[:,0] is the first column.  The rows are ordered
    # from top to bottom.
    component_grid = Property

    # The internal component grid, in row-major order.  This gets updated
    # when any of the following traits change: shape, components, grid_components
    _grid = Array

    _cached_total_size = Any
    _h_size_prefs = Any
    _v_size_prefs = Any

    class SizePrefs(object):
        """ Object to hold size preferences across spans in a particular
        dimension.  For instance, if SizePrefs is being used for the row
        axis, then each element in the arrays below express sizing information
        about the corresponding column.
        """

        # The maximum size of non-resizable elements in the span.  If an
        # element of this array is 0, then its corresponding span had no
        # non-resizable components.
        fixed_lengths = Array

        # The maximum preferred size of resizable elements in the span.
        # If an element of this array is 0, then its corresponding span
        # had no resizable components with a non-zero preferred size.
        resizable_lengths = Array

        # The direction of resizability associated with this SizePrefs
        # object.  If this SizePrefs is sizing along the X-axis, then
        # direction should be "h", and correspondingly for the Y-axis.
        direction = Enum("h", "v")

        # The index into a size tuple corresponding to our orientation
        # (0 for horizontal, 1 for vertical).  This is derived from
        # **direction** in the constructor.
        index = Int(0)

        def __init__(self, length, direction):
            """ Initializes this prefs object with empty arrays of the given
            length and with the given direction. """
            self.fixed_lengths = zeros(length)
            self.resizable_lengths = zeros(length)
            self.direction = direction
            if direction == "h":
                self.index = 0
            else:
                self.index = 1
            return

        def update_from_component(self, component, index):
            """ Given a component at a particular index along this SizePref's
            axis, integrates the component's resizability and sizing information
            into self.fixed_lengths and self.resizable_lengths. """
            resizable = self.direction in component.resizable
            pref_size = component.get_preferred_size()
            self.update_from_pref_size(pref_size[self.index], index, resizable)

        def update_from_pref_size(self, pref_length, index, resizable):
            """ Raises the recorded maximum (resizable or fixed, depending on
            *resizable*) for span *index* to *pref_length* if it is larger. """
            if resizable:
                if pref_length > self.resizable_lengths[index]:
                    self.resizable_lengths[index] = pref_length
            else:
                if pref_length > self.fixed_lengths[index]:
                    self.fixed_lengths[index] = pref_length
            return

        def get_preferred_size(self):
            """ Returns, per span, the larger of the fixed and resizable
            preferred lengths. """
            return amax((self.fixed_lengths, self.resizable_lengths), axis=0)

        def compute_size_array(self, size):
            """ Given a length along the axis corresponding to this SizePref,
            returns an array of lengths to assign each cell, taking into account
            resizability and preferred sizes.
            """
            # There are three basic cases for each column:
            #   1. size < total fixed size
            #   2. total fixed size < size < fixed size + resizable preferred size
            #   3. fixed size + resizable preferred size < size
            #
            # In all cases, non-resizable components get their full width.
            #
            # For resizable components with non-zero preferred size, the following
            # actions are taken depending on case:
            #   case 1: They get sized to 0.
            #   case 2: They get a fraction of their preferred size, scaled based on
            #           the amount of remaining space after non-resizable components
            #           get their full size.
            #   case 3: They get their full preferred size.
            #
            # For resizable components with no preferred size (indicated in our scheme
            # by having a preferred size of 0), the following actions are taken
            # depending on case:
            #   case 1: They get sized to 0.
            #   case 2: They get sized to 0.
            #   case 3: All resizable components with no preferred size split the
            #           remaining space evenly, after fixed width and resizable
            #           components with preferred size get their full size.
            fixed_lengths = self.fixed_lengths
            resizable_lengths = self.resizable_lengths
            return_lengths = zeros_like(fixed_lengths)

            fixed_size = sum(fixed_lengths)
            fixed_length_indices = fixed_lengths > resizable_lengths
            resizable_indices = resizable_lengths > fixed_lengths
            fully_resizable_indices = (resizable_lengths + fixed_lengths == 0)
            preferred_size = sum(fixed_lengths[fixed_length_indices]) + \
                                    sum(resizable_lengths[~fixed_length_indices])

            # Regardless of the relationship between available space and
            # resizable preferred sizes, columns/rows where the non-resizable
            # component is largest will always get that amount of space.
            return_lengths[fixed_length_indices] = fixed_lengths[
                fixed_length_indices]

            if size <= fixed_size:
                # We don't use fixed_length_indices here because that mask is
                # just where non-resizable components were larger than resizable
                # ones.  If our allotted size is less than the total fixed size,
                # then we should give all non-resizable components their desired
                # size.
                indices = fixed_lengths > 0
                return_lengths[indices] = fixed_lengths[indices]
                return_lengths[~indices] = 0

            elif size > fixed_size and (fixed_lengths >
                                        resizable_lengths).all():
                # If we only have to consider non-resizable lengths, and we have
                # extra space available, then we need to give each column an
                # amount of extra space corresponding to its size.
                desired_space = sum(fixed_lengths)
                if desired_space > 0:
                    scale = size / desired_space
                    return_lengths = (fixed_lengths * scale).astype(int)

            elif size <= preferred_size or not fully_resizable_indices.any():
                # If we don't have enough room to give all the non-fully resizable
                # components their preferred size, or we have more than enough
                # room for them and no fully resizable components to take up
                # the extra space, then we just scale the resizable components
                # up or down based on the amount of extra space available.
                delta_lengths = resizable_lengths[resizable_indices] - \
                                        fixed_lengths[resizable_indices]
                desired_space = sum(delta_lengths)
                if desired_space > 0:
                    avail_space = size - sum(
                        fixed_lengths)  #[fixed_length_indices])
                    scale = avail_space / desired_space
                    return_lengths[resizable_indices] = (fixed_lengths[resizable_indices] + \
                            scale * delta_lengths).astype(int)

            elif fully_resizable_indices.any():
                # We have enough room to fit all the non-resizable components
                # as well as components with preferred sizes, and room left
                # over for the fully resizable components.  Give the resizable
                # components their desired amount of space, and then give the
                # remaining space to the fully resizable components.
                return_lengths[resizable_indices] = resizable_lengths[
                    resizable_indices]
                avail_space = size - preferred_size
                count = sum(fully_resizable_indices)
                space = avail_space / count
                return_lengths[fully_resizable_indices] = space

            else:
                raise RuntimeError("Unhandled sizing case in GridContainer")

            return return_lengths

    def get_preferred_size(self, components=None):
        """ Returns the size (width,height) that is preferred for this component.

        Overrides PlotComponent.

        *components* is an optional nested sequence / 2-D array of components
        to size; it defaults to **component_grid**.  If fixed_preferred_size
        is set it is returned immediately.  Otherwise this rebuilds the
        per-column and per-row size preferences (_h_size_prefs,
        _v_size_prefs) and _cached_total_size that _do_layout relies on, and
        returns outer_bounds when resizable == "" or the computed total size
        otherwise.
        """
        if self.fixed_preferred_size is not None:
            return self.fixed_preferred_size

        if components is None:
            components = self.component_grid
        else:
            # Convert to array; hopefully it is a list or tuple of list/tuples
            components = array(components)

        # These arrays track the maximum widths in each column and maximum
        # height in each row.
        numrows, numcols = self.shape

        no_visible_components = True
        self._h_size_prefs = GridPlotContainer.SizePrefs(numcols, "h")
        self._v_size_prefs = GridPlotContainer.SizePrefs(numrows, "v")
        self._pref_size_cache = {}
        for i, row in enumerate(components):
            for j, component in enumerate(row):
                if not self._should_layout(component):
                    continue
                else:
                    no_visible_components = False
                    self._h_size_prefs.update_from_component(component, j)
                    self._v_size_prefs.update_from_component(component, i)

        total_width = sum(
            self._h_size_prefs.get_preferred_size()) + self.hpadding
        total_height = sum(
            self._v_size_prefs.get_preferred_size()) + self.vpadding
        total_size = array([total_width, total_height])

        # Account for spacing.  Each cell gets `spacing` on both sides, so N
        # cells along an axis use 2*N*spacing; components.shape is reversed
        # to (columns, rows) to line up with (h_spacing, v_spacing).  No
        # spacing is added along an axis whose total size is still 0.
        if self.spacing is None:
            spacing = zeros(2)
        else:
            spacing = array(self.spacing)
        total_spacing = array(
            components.shape[::-1]) * spacing * 2 * (total_size > 0)
        total_size += total_spacing

        for orientation, ndx in (("h", 0), ("v", 1)):
            if (orientation not in self.resizable) and \
               (orientation not in self.fit_components):
                total_size[ndx] = self.outer_bounds[ndx]
            elif no_visible_components or (total_size[ndx] == 0):
                total_size[ndx] = self.default_size[ndx]

        self._cached_total_size = total_size
        if self.resizable == "":
            return self.outer_bounds
        else:
            return self._cached_total_size

    def _do_layout(self):
        """ Positions and sizes every laid-out component in its grid cell,
        then asks each of them to lay itself out. """
        # If we don't have cached size_prefs, then we need to call
        # get_preferred_size to build them.
        if self._cached_total_size is None:
            self.get_preferred_size()

        # If we need to fit our components, then rather than using our
        # currently assigned size to do layout, we use the preferred
        # size we computed from our components.
        size = array(self.bounds)
        if self.fit_components != "":
            self.get_preferred_size()
            if "h" in self.fit_components:
                size[0] = self._cached_total_size[0] - self.hpadding
            if "v" in self.fit_components:
                size[1] = self._cached_total_size[1] - self.vpadding

        # Compute total_spacing and spacing, which are used in computing
        # the bounds and positions of all the components.  _grid.shape is
        # (rows, columns) while spacing is (h_spacing, v_spacing), so the
        # shape is reversed to (columns, rows): horizontal spacing is used
        # once per column and vertical spacing once per row.  (Calling
        # .transpose() on this 1-D array would be a no-op.)
        shape = array(self._grid.shape)[::-1]
        if self.spacing is None:
            spacing = array([0, 0])
        else:
            spacing = array(self.spacing)
        total_spacing = spacing * 2 * shape

        # Compute the total space used by non-resizable and resizable components
        # with non-zero preferred sizes.
        widths = self._h_size_prefs.compute_size_array(size[0] -
                                                       total_spacing[0])
        heights = self._v_size_prefs.compute_size_array(size[1] -
                                                        total_spacing[1])

        # Set the baseline h and v positions for each cell.  Resizable components
        # will get these as their position, but non-resizable components will have
        # to be aligned in H and V.  Rows are stored top to bottom, so heights
        # are accumulated from the last row and the result reversed.
        summed_widths = cumsum(hstack(([0], widths[:-1])))
        summed_heights = cumsum(hstack(([0], heights[-1:0:-1])))
        h_positions = (2 * (arange(self._grid.shape[1]) + 1) -
                       1) * spacing[0] + summed_widths
        v_positions = (2 * (arange(self._grid.shape[0]) + 1) -
                       1) * spacing[1] + summed_heights
        v_positions = v_positions[::-1]

        # Loop over all rows and columns, assigning position, setting bounds for
        # resizable components, and aligning non-resizable ones
        valign = self.valign
        halign = self.halign
        for j, row in enumerate(self._grid):
            for i, component in enumerate(row):
                if not self._should_layout(component):
                    continue

                r = component.resizable
                x = h_positions[i]
                y = v_positions[j]
                w = widths[i]
                h = heights[j]

                if "v" not in r:
                    # Component is not vertically resizable
                    if valign == "top":
                        y += h - component.outer_height
                    elif valign == "center":
                        y += (h - component.outer_height) / 2
                if "h" not in r:
                    # Component is not horizontally resizable
                    if halign == "right":
                        x += w - component.outer_width
                    elif halign == "center":
                        x += (w - component.outer_width) / 2

                component.outer_position = [x, y]
                bounds = list(component.outer_bounds)
                if "h" in r:
                    bounds[0] = w
                if "v" in r:
                    bounds[1] = h

                component.outer_bounds = bounds
                component.do_layout()

        return

    def _reflow_layout(self):
        """ Re-computes self._grid based on self.components and self.shape.
        Adjusts self.shape accordingly.

        When there are more components than cells, the column count is kept
        and enough rows are added for the extra components to wrap onto.  If
        the column count is 0 the row count is kept instead; if both are 0
        the components are laid out in a single row.  Unused trailing cells
        are set to None.
        """
        num_components = len(self.components)
        numrows, numcols = self.shape[0], self.shape[1]
        if numrows * numcols < num_components:
            # -(-a // b) is integer ceiling division.
            if numcols > 0:
                numrows = -(-num_components // numcols)
            elif numrows > 0:
                numcols = -(-num_components // numrows)
            else:
                numrows, numcols = 1, num_components
            # Note: assigning shape fires _shape_changed, which re-enters this
            # method with a shape that now has enough cells.
            self.shape = (numrows, numcols)
        grid = array(self.components, dtype=object)
        # resize() pads with zeros; those padding cells are replaced by None.
        grid.resize(self.shape)
        grid[grid == 0] = None
        self._grid = grid
        self._layout_needed = True
        return

    def _shape_changed(self, old, new):
        self._reflow_layout()

    def __components_changed(self, old, new):
        self._reflow_layout()

    def __components_items_changed(self, event):
        self._reflow_layout()

    def _get_component_grid(self):
        return self._grid

    def _set_component_grid(self, val):
        """ Sets the grid from nested sequences or a 2-D array.

        Components no longer in the grid get their container cleared; new
        ones are removed from any previous container and adopted.  The shape
        is updated without change notification, and the flattened grid
        becomes the component list.
        """
        grid = array(val)
        grid_set = set(grid.flatten())

        # Figure out which of the components in the component_grid are new,
        # and which have been removed.
        existing = set(array(self._grid).flatten())
        new = grid_set - existing
        removed = existing - grid_set

        for component in removed:
            if component is not None:
                component.container = None
        for component in new:
            if component is not None:
                if component.container is not None:
                    component.container.remove(component)
                component.container = self

        self.set(shape=grid.shape, trait_change_notify=False)
        self._components = list(grid.flatten())

        if self._should_compact():
            self.compact()

        self.invalidate_draw()
        return
コード例 #29
0
class UndoItem(AbstractUndoItem):
    """ A change to an object trait, which can be undone.

    Records the changed object, the trait name, and the old and new values,
    so the change can be undone (old value restored) or redone (new value
    re-applied).  List values are stored as shallow copies.
    """
    #---------------------------------------------------------------------------
    #  Trait definitions:
    #---------------------------------------------------------------------------

    # Object the change occurred on
    object = Trait(HasTraits)
    # Name of the trait that changed
    name = Str
    # Old value of the changed trait
    old_value = Property
    # New value of the changed trait
    new_value = Property

    #---------------------------------------------------------------------------
    #  Implementation of the 'old_value' and 'new_value' properties:
    #---------------------------------------------------------------------------

    def _get_old_value(self):
        return self._old_value

    def _set_old_value(self, value):
        # Copy lists so later in-place changes to the original list do not
        # alter the recorded value.
        if isinstance(value, list):
            value = value[:]
        self._old_value = value

    def _get_new_value(self):
        return self._new_value

    def _set_new_value(self, value):
        # Copy lists so later in-place changes to the original list do not
        # alter the recorded value.
        if isinstance(value, list):
            value = value[:]
        self._new_value = value

    #---------------------------------------------------------------------------
    #  Undoes the change:
    #---------------------------------------------------------------------------

    def undo(self):
        """ Undoes the change.

        Best effort: an error raised while restoring the old value is
        ignored.
        """
        try:
            setattr(self.object, self.name, self.old_value)
        except Exception:
            # Deliberately best-effort; catching Exception (not a bare
            # except) lets KeyboardInterrupt/SystemExit propagate.
            pass

    #---------------------------------------------------------------------------
    #  Re-does the change:
    #---------------------------------------------------------------------------

    def redo(self):
        """ Re-does the change.

        Best effort: an error raised while re-applying the new value is
        ignored.
        """
        try:
            setattr(self.object, self.name, self.new_value)
        except Exception:
            # Deliberately best-effort; catching Exception (not a bare
            # except) lets KeyboardInterrupt/SystemExit propagate.
            pass

    #---------------------------------------------------------------------------
    #  Merges two undo items if possible:
    #---------------------------------------------------------------------------

    def merge_undo(self, undo_item):
        """ Merges two undo items if possible.

        *undo_item* is only considered if it is an instance of this item's
        class referring to the same object and trait name, and its new value
        has the same type as this item's new value.  On a merge, this item's
        new value is replaced by *undo_item*'s.

        Returns True if the items were merged, False otherwise.
        """
        try:
            from collections.abc import Sequence
        except ImportError:  # Python 2
            from collections import Sequence

        # Undo items are potentially mergeable only if they are of the same
        # class and refer to the same object trait, so check that first:
        if (isinstance(undo_item, self.__class__)
                and (self.object is undo_item.object)
                and (self.name == undo_item.name)):
            v1 = self.new_value
            v2 = undo_item.new_value
            t1 = type(v1)
            if t1 is type(v2):

                # Test the value, not its type: isinstance(t1, str) is
                # always False because t1 is a type object.
                if isinstance(v1, str):
                    # Merge two undo items if they have new values which are
                    # strings which only differ by one character (corresponding
                    # to a single character insertion, deletion or replacement
                    # operation in a text editor):
                    n1 = len(v1)
                    n2 = len(v2)
                    n = min(n1, n2)
                    i = 0
                    while (i < n) and (v1[i] == v2[i]):
                        i += 1
                    if v1[i + (n2 <= n1):] == v2[i + (n2 >= n1):]:
                        self.new_value = v2
                        return True

                elif isinstance(v1, Sequence):
                    # Merge sequence types only if a single element has changed
                    # from the 'original' value, and the element type is a
                    # simple Python type:
                    v1 = self.old_value
                    if isinstance(v1, Sequence):
                        # Note: wxColour says it's a sequence type, but it
                        # doesn't support 'len', so we handle the exception
                        # just in case other classes have similar behavior:
                        try:
                            if len(v1) == len(v2):
                                diffs = 0
                                for i, item in enumerate(v1):
                                    titem = type(item)
                                    item2 = v2[i]
                                    if ((titem not in SimpleTypes)
                                            or (titem is not type(item2))
                                            or (item != item2)):
                                        diffs += 1
                                        if diffs >= 2:
                                            return False
                                if diffs == 0:
                                    return False
                                self.new_value = v2
                                return True
                        except Exception:
                            pass

                elif t1 in NumericTypes:
                    # Always merge simple numeric trait changes:
                    self.new_value = v2
                    return True
        return False

    #---------------------------------------------------------------------------
    #  Returns a 'pretty print' form of the object:
    #---------------------------------------------------------------------------

    def __repr__(self):
        """ Returns a "pretty print" form of the object.
        """
        n = self.name
        cn = self.object.__class__.__name__
        return 'undo( %s.%s = %s )\nredo( %s.%s = %s )' % (
            cn, n, self.old_value, cn, n, self.new_value)
コード例 #30
0
class FETSLSEval(FETSEval):

    x_slice = slice(0, 0)
    parent_fets = Instance(FETSEval)

    nip_disc = Int(0)  # number of integration points on the discontinuity

    def setup(self, sctx, n_ip):
        '''
        Overloads the default setup method.

        The material state array has to account for a different number of
        integration points per element.  For each of the n_ip integration
        points, the matching slice of length ``self.m_arr_size`` of
        ``sctx.elem_state_array`` is exposed as ``sctx.mats_state_array``
        before ``self.mats_eval.setup(sctx)`` is called.

        TODO: original setup can be used after adaptation the ip_coords param
        '''
        for ip in range(n_ip):
            chunk = self.m_arr_size
            start = ip * chunk
            sctx.mats_state_array = sctx.elem_state_array[start:start + chunk]
            self.mats_eval.setup(sctx)

    n_nodes = Property  # TODO: define dependencies

    @cached_property
    def _get_n_nodes(self):
        '''Number of nodes of the parent element.

        Computed as the parent's element DOFs divided by its DOFs per node.
        Floor division keeps the node count an integer; with true division
        (Python 3) '/' would return a float.
        '''
        return self.parent_fets.n_e_dofs // self.parent_fets.n_nodal_dofs

    # dots_class = DOTSUnstructuredEval
    dots_class = Class(DOTSEval)

    int_order = Int(1)

    mats_eval = Delegate('parent_fets')
    mats_eval_pos = Trait(None, Instance(IMATSEval))
    mats_eval_neg = Trait(None, Instance(IMATSEval))
    mats_eval_disc = Trait(None, Instance(IMATSEval))
    dim_slice = Delegate('parent_fets')

    dof_r = Delegate('parent_fets')
    geo_r = Delegate('parent_fets')
    n_nodal_dofs = Delegate('parent_fets')
    n_e_dofs = Delegate('parent_fets')

    get_dNr_mtx = Delegate('parent_fets')
    get_dNr_geo_mtx = Delegate('parent_fets')

    get_N_geo_mtx = Delegate('parent_fets')

    def get_B_mtx(self, r_pnt, X_mtx, node_ls_values, r_ls_value):
        '''Return the parent element's B matrix at r_pnt for geometry X_mtx.

        node_ls_values and r_ls_value are part of the signature but are not
        used by this implementation.
        '''
        return self.parent_fets.get_B_mtx(r_pnt, X_mtx)

    def get_u(self, sctx, u):
        '''Interpolate u at sctx.loc using the parent element's N matrix.'''
        return dot(self.parent_fets.get_N_mtx(sctx.loc), u)

    def get_eps_eng(self, sctx, u):
        '''Return B . u, with B taken from the parent element at sctx.loc
        for geometry sctx.X.'''
        return dot(self.parent_fets.get_B_mtx(sctx.loc, sctx.X), u)

    dof_r = Delegate('parent_fets')
    geo_r = Delegate('parent_fets')

    node_ls_values = Array(float)

    tri_subdivision = Int(0)

    def get_triangulation(self, point_set):
        '''Triangulate each point group in *point_set* and merge the results.

        The spatial dimension is read from ``point_set[0].shape[1]``
        (presumably every entry is an (n_points, dim) array -- confirm
        against callers).  For dim == 1 no triangulation is done: a pair
        [min/max extents of the first two groups, line connectivity
        [[0, 1], [2, 3]]] is returned.  Otherwise each group is padded with
        zero coordinates up to 3D, triangulated with tvtk's Delaunay2D, and
        the per-group points and triangles are concatenated, with triangle
        indices shifted to address the merged point list.

        Returns
        -------
        list
            ``[points, triangles]``.
        '''
        dim = point_set[0].shape[1]
        n_add = 3 - dim
        if dim == 1:  # sideway for 1D
            structure = [
                array([
                    min(point_set[0]),
                    max(point_set[0]),
                    min(point_set[1]),
                    max(point_set[1])
                ],
                      dtype=float),
                array([[0, 1], [2, 3]], dtype=int)
            ]
            return structure
        points_list = []
        triangles_list = []
        point_offset = 0
        for pts in point_set:
            if self.tri_subdivision == 1:
                # Add the group's centroid as an extra point.
                new_pt = average(pts, 0)
                pts = vstack((pts, new_pt))
            if n_add > 0:
                # Pad the coordinates with zeros up to 3D.
                points = hstack(
                    [pts, zeros([pts.shape[0], n_add], dtype='float_')])
            else:
                # Already 3D; without this branch `points` would be unbound
                # (or left over from the previous group).
                points = pts
            # Create a polydata with the points we just created.
            profile = tvtk.PolyData(points=points)

            # Perform a 2D Delaunay triangulation on them.
            delny = tvtk.Delaunay2D(input=profile, offset=1.e1)
            tri = delny.output
            tri.update()  # initiate triangulation
            triangles = array(tri.polys.data, dtype=int_)
            pt = tri.points.data
            # The flat cell data is split into rows of 4 and the first entry
            # of each row (VTK's per-cell point count) is dropped.  Floor
            # division keeps the row count an integer for reshape().
            tri = (triangles.reshape((triangles.shape[0] // 4), 4))[:, 1:]
            points_list += list(pt)
            triangles_list += list(tri + point_offset)
            # The next group's indices start after every point appended
            # above, including any point no triangle references.
            point_offset += len(pt)
        points = array(points_list)
        triangles = array(triangles_list)
        return [points, triangles]

    vtk_point_ip_map = Property(Array(Int))

    def _get_vtk_point_ip_map(self):
        '''
        Map every visualization point to an integration point.

        For each row of ``self.vtk_r`` the index of the closest row of
        ``self.ip_coords`` (by ``cdist`` distance in local coordinates) is
        stored; the result is an integer array with one entry per
        visualization point.
        '''
        query = zeros((1, 3), dtype='float_')
        nearest = zeros(self.vtk_r.shape[0], dtype='int_')
        for idx, vis_pt in enumerate(self.vtk_r):
            query[0, self.dim_slice] = vis_pt
            distances = cdist(query, self.ip_coords)
            nearest[idx] = argmin(distances)
        return array(nearest)

    def get_ip_coords(self, int_triangles, int_order):
        '''Get the array of integration points.

        Parameters
        ----------
        int_triangles : (points, cells)
            Point coordinates and cell connectivity.  The number of columns
            of ``cells`` selects the geometry: 1 = points, 2 = lines,
            3 = triangles, 4 = tetrahedrons.
        int_order : int
            Integration order.

        Returns
        -------
        ndarray
            Integration point coordinates (dtype float), listed cell by cell.

        Raises
        ------
        TraitError
            If int_order != 1 for points, or the cell size is unsupported.
        NotImplementedError
            For orders/geometries without a rule here (lines: orders other
            than 1 and 2; triangles: orders other than 1, 3 and 5;
            tetrahedrons: all orders).
        '''
        points, cells = int_triangles
        nodes_per_cell = cells.shape[1]

        def cell_points(cell):
            # coordinates of the nodes of one cell
            return points[ix_(cell)]

        coords = []
        if nodes_per_cell == 1:  # 0D - points
            if int_order != 1:
                raise TraitError('does not make sense')
            coords.append(points[0])
        elif nodes_per_cell == 2:  # 1D - lines
            if int_order == 1:
                coords.extend(average(cell_points(c), 0) for c in cells)
            elif int_order == 2:
                weights = array([[0.21132486540518713, 0.78867513459481287],
                                 [0.78867513459481287, 0.21132486540518713]])
                for c in cells:
                    xs = cell_points(c)
                    coords.append(average(xs, 0, weights[0]))
                    coords.append(average(xs, 0, weights[1]))
            else:
                raise NotImplementedError
        elif nodes_per_cell == 3:  # 2D - triangles
            if int_order == 1:
                coords.extend(average(cell_points(c), 0) for c in cells)
            elif int_order == 3:
                # centroid followed by three weighted interior points
                weights = array([[0.6, 0.2, 0.2], [0.2, 0.6, 0.2],
                                 [0.2, 0.2, 0.6]])
                for c in cells:
                    xs = cell_points(c)
                    coords.append(average(xs, 0))
                    coords.extend(average(xs, 0, w) for w in weights)
            elif int_order == 5:
                # centroid followed by six weighted interior points
                weights = array([[0.0597158717, 0.4701420641, 0.4701420641],
                                 [0.4701420641, 0.0597158717, 0.4701420641],
                                 [0.4701420641, 0.4701420641, 0.0597158717],
                                 [0.7974269853, 0.1012865073, 0.1012865073],
                                 [0.1012865073, 0.7974269853, 0.1012865073],
                                 [0.1012865073, 0.1012865073, 0.7974269853]])
                for c in cells:
                    xs = cell_points(c)
                    coords.append(average(xs, 0))
                    coords.extend(average(xs, 0, w) for w in weights)
            else:  # orders 2, 4 and anything else have no rule here
                raise NotImplementedError
        elif nodes_per_cell == 4:  # 3D - tetrahedrons
            raise NotImplementedError
        else:
            raise TraitError('unsupported geometric form with %s nodes ' %
                             nodes_per_cell)
        return array(coords, dtype='float_')

    def get_ip_weights(self, int_triangles, int_order):
        '''Get the array of integration weights.

        Each weight is a reference-rule factor multiplied by the Jacobian
        determinant of its sub-cell, so the weights can be combined directly
        with integrand values at the corresponding integration points
        (presumably the coordinates generated for the same
        ``int_triangles``/``int_order``).

        @param int_triangles: pair ``(points, triangles)``; each row of
            ``triangles`` holds node indices into ``points``. The number of
            columns selects the geometry (1: point, 2: line, 3: triangle,
            4: tetrahedron).
        @param int_order: integration order (points: 1; lines: 1, 2;
            triangles: 1, 3, 5).
        @return: 1D float array with one weight per integration point.
        @raise NotImplementedError: unsupported order, or tetrahedra.
        @raise TraitError: order other than 1 for points, or an unknown
            number of nodes per sub-cell.
        '''
        points, triangles = int_triangles
        n_nodes = triangles.shape[1]
        weights = []
        if n_nodes == 1:  # 0D - points
            if int_order != 1:
                raise TraitError('does not make sense')
            weights.append(1.)
        elif n_nodes == 2:  # 1D - lines
            # Gauss-Legendre factors on the reference interval [-1, 1].
            factors = {1: (2.,), 2: (1., 1.)}.get(int_order)
            if factors is None:
                raise NotImplementedError
            for node_ids in triangles:
                r_pnt = points[ix_(node_ids)]
                # Half the segment length: Jacobian of the map from [-1, 1].
                J_det_ip = norm(r_pnt[1] - r_pnt[0]) * 0.5
                weights.extend(f * J_det_ip for f in factors)
        elif n_nodes == 3:  # 2D - triangles
            # Factors sum to 1 because _get_J_det_ip returns the cell area.
            factors = {
                1: (1.,),
                # 4-point rule (-27/48 and 25/48)
                3: (-0.5625, 0.52083333333333337, 0.52083333333333337,
                    0.52083333333333337),
                # 7-point rule
                5: (0.225, 0.1323941527, 0.1323941527, 0.1323941527,
                    0.1259391805, 0.1259391805, 0.1259391805),
            }.get(int_order)
            if factors is None:
                raise NotImplementedError
            for node_ids in triangles:
                J_det_ip = self._get_J_det_ip(points[ix_(node_ids)])
                weights.extend(f * J_det_ip for f in factors)
        elif n_nodes == 4:  # 3D - tetrahedrons
            raise NotImplementedError
        else:
            raise TraitError('unsupported geometric form with %s nodes ' %
                             triangles.shape[1])
        # dtype=float (float64) instead of the NumPy-1.x-only 'float_' alias.
        return array(weights, dtype=float)

    def _get_J_det_ip(self, r_pnt):
        '''
        Helper function 
        just for 2D
        #todo:3D
        @param r_pnt:
        '''
        dNr_geo = self.dNr_geo_triangle
        return det(dot(dNr_geo,
                       r_pnt[:, :2])) / 2.  # factor 2 due to triangular form

    # Constant 2x3 float matrix produced by _get_dNr_geo_triangle (by its
    # values, the reference-coordinate derivatives of the linear triangle's
    # shape functions); read by _get_J_det_ip to build the Jacobian.
    dNr_geo_triangle = Property(Array(float))

    @cached_property
    def _get_dNr_geo_triangle(self):
        dN_geo = array([[-1., 1., 0.], [-1., 0., 1.]], dtype='float_')
        return dN_geo

    def get_corr_pred(self,
                      sctx,
                      u,
                      du,
                      tn,
                      tn1,
                      u_avg=None,
                      B_mtx_grid=None,
                      J_det_grid=None,
                      ip_coords=None,
                      ip_weights=None):
        '''
        Corrector and predictor evaluation.

        Integrates the element residual ``F = sum(B^T sig * w * J)`` and the
        tangent ``K = sum(B^T D B * w * J)`` over the integration points.

        @param sctx: spatial context. Reads ``elem_state_array`` and
            ``ls_val``; sets ``fets_eval``, ``r_pnt``, ``mats_state_array``
            and ``r_ls`` for the material evaluation.
        @param u: current element displacement vector
        @param du: displacement increment vector
        @param tn: previous time, passed through to the material model
        @param tn1: current time, passed through to the material model
        @param u_avg: not used here
        @param B_mtx_grid: cached B matrices indexed by integration point
            (required)
        @param J_det_grid: cached Jacobian determinants indexed by
            integration point (required)
        @param ip_coords: integration point coordinates; defaults to
            ``self.ip_coords``
        @param ip_weights: integration weights; defaults to
            ``self.ip_weights``
        @return: tuple ``(F, K)``
        @raise TypeError: if ``B_mtx_grid`` or ``J_det_grid`` is missing
        '''
        # Use `is None`, not `== None`: for numpy arrays `== None` compares
        # element-wise and the `if` then raises "truth value ... ambiguous".
        if J_det_grid is None or B_mtx_grid is None:
            # On-the-fly evaluation of J and B is not supported at the
            # moment, so the cached grids are mandatory.
            raise TypeError('get_corr_pred requires both B_mtx_grid and '
                            'J_det_grid (non-cached evaluation is not '
                            'supported)')
        if ip_coords is None:
            ip_coords = self.ip_coords
        if ip_weights is None:
            ip_weights = self.ip_weights

        n_e_dofs = self.n_e_dofs
        m_arr_size = self.m_arr_size
        K = zeros((n_e_dofs, n_e_dofs))
        F = zeros(n_e_dofs)
        sctx.fets_eval = self

        for ip, (r_pnt, wt) in enumerate(zip(ip_coords, ip_weights)):
            sctx.r_pnt = r_pnt
            J_det = J_det_grid[ip, ...]
            B_mtx = B_mtx_grid[ip, ...]

            eps_mtx = dot(B_mtx, u)
            d_eps_mtx = dot(B_mtx, du)
            # This integration point's slice of the element state array.
            sctx.mats_state_array = sctx.elem_state_array[
                ip * m_arr_size:(ip + 1) * m_arr_size]
            sctx.r_ls = sctx.ls_val[ip]
            sig_mtx, D_mtx = self.get_mtrl_corr_pred(sctx, eps_mtx, d_eps_mtx,
                                                     tn, tn1)
            w_J = wt * J_det
            K += dot(B_mtx.T, dot(D_mtx, B_mtx)) * w_J
            F += dot(B_mtx.T, sig_mtx) * w_J

        return F, K

    def get_J_det(self, r_pnt, X_mtx, ls_nodes,
                  ls_r):  # unified interface for caching
        '''Return ``self._get_J_det(r_pnt, X_mtx)`` as a float64 array.

        ``ls_nodes`` and ``ls_r`` are not used here; they only keep the
        signature uniform for the caching code.
        '''
        # dtype=float (float64) instead of the NumPy-1.x-only 'float_' alias.
        return array(self._get_J_det(r_pnt, X_mtx), dtype=float)

    def get_mtrl_corr_pred(self, sctx, eps_mtx, d_eps, tn, tn1):
        '''Select a material model by the level-set value and evaluate it.

        Uses ``sctx.r_ls``:

        - exactly 0: ``mats_eval_disc``
        - positive: ``mats_eval_pos``
        - negative: ``mats_eval_neg``

        Falls back to ``mats_eval`` when the selected model is unset (falsy).
        It also falls back when the value is neither ==, > nor < 0 (e.g.
        NaN).

        @return: ``(sig_mtx, D_mtx)`` from the chosen model's
            ``get_corr_pred``
        '''
        level = sctx.r_ls
        if level == 0.:
            chosen = self.mats_eval_disc
        elif level > 0.:
            chosen = self.mats_eval_pos
        elif level < 0.:
            chosen = self.mats_eval_neg
        else:
            chosen = None
        if not chosen:
            chosen = self.mats_eval
        sig_mtx, D_mtx = chosen.get_corr_pred(sctx, eps_mtx, d_eps, tn, tn1)
        return sig_mtx, D_mtx