class ShapeSelector(HasTraits):
    """Pick a concrete ``Shape`` subclass by name and edit an instance of it.

    ``select`` offers the names of the direct subclasses of ``Shape``; each
    time it changes, ``shape`` is replaced by a fresh instance of the chosen
    class.  The view shows the chooser above two custom editors of the same
    instance (its ``view`` and its ``view_info``).
    """

    # Names of the direct subclasses of Shape, offered as the choices.
    select = Enum(*[cls.__name__ for cls in Shape.__subclasses__()])
    # The instance currently being edited.
    shape = Instance(Shape)

    view = View(
        VGroup(
            Item("select", show_label=False),
            VSplit(
                Item("shape",
                     style="custom",
                     editor=InstanceEditor(view="view")),
                Item("shape",
                     style="custom",
                     editor=InstanceEditor(view="view_info")),
                show_labels=False,
            ),
        ),
        width=350,
        height=300,
        resizable=True,
    )

    def __init__(self, **traits):
        super(ShapeSelector, self).__init__(**traits)
        # Build the initial shape explicitly; the change handler is not
        # guaranteed to have run for the default value of `select`.
        self._select_changed()

    def _select_changed(self):
        """Replace ``shape`` with a new instance of the class named by ``select``."""
        matches = [
            candidate
            for candidate in Shape.__subclasses__()
            if candidate.__name__ == self.select
        ]
        chosen = matches[0]
        self.shape = chosen()
Exemple #2
0
class Trace(PolyLine):
    """Poly-line whose vertices come from the recent history of x/y/z expressions.

    On ``update`` the coordinate arrays are fetched from the ``x``, ``y`` and
    ``z`` expressions with ``first=-length`` and stacked column-wise into
    ``points``.
    """
    x = Instance(Expression)
    y = Instance(Expression)
    z = Instance(Expression)
    # Passed (negated) as `first=` to Expression.get_array.
    length = Int(0)

    traits_view = View(
        Item(name='length', label='Frame'),
        Item(name='x', style='custom'),
        Item(name='y', style='custom'),
        Item(name='z', style='custom'),
        Item(name='properties',
             editor=InstanceEditor(),
             label='Render properties'),
        title='Line properties',
    )

    def __init__(self, *args, **kwargs):
        PolyLine.__init__(self, *args, **kwargs)

    def update(self):
        """Rebuild ``points`` from the x/y/z expressions, then refresh the line."""
        start = -self.length
        columns = [axis.get_array(first=start)
                   for axis in (self.x, self.y, self.z)]
        # One row per sample: (x, y, z).
        self.points = array(columns).T
        super(Trace, self).update()
Exemple #3
0
class AnalogInputViewer(traits.HasTraits):
    """Shows one ``AnalogInputChannelViewer`` per USB analog input channel (0-3)."""
    channels = traits.List
    # Maps each channel's device_channel_num to its position in `channels`.
    usb_device_number2index = traits.Property(depends_on='channels')

    @traits.cached_property
    def _get_usb_device_number2index(self):
        return {
            channel.device_channel_num: index
            for index, channel in enumerate(self.channels)
        }

    traits_view = View(
        Group(
            Item('channels',
                 style='custom',
                 editor=ListEditor(rows=3,
                                   editor=InstanceEditor(),
                                   style='custom'),
                 resizable=True),
        ),
        resizable=True,
        width=800,
        height=600,
        title='Analog Input',
    )

    def __init__(self, *args, **kwargs):
        super(AnalogInputViewer, self).__init__(*args, **kwargs)
        # Create a viewer for each of the device's four channels.
        for usb_channel_num in range(4):
            viewer = AnalogInputChannelViewer(device_channel_num=usb_channel_num)
            self.channels.append(viewer)
Exemple #4
0
class Frame(HasTraits):
    """A named coordinate frame whose transform is relative to a parent frame.

    ``T`` holds this frame's local transform as an ``Expression``; ``evalT``
    composes it with the parent's composed transform.  ``variables`` is
    delegated to the parent frame.
    """
    # The frame this one is expressed relative to.
    parent = This
    # Local transform expression (labelled 'Matrix4x4' in the UI).
    T = Instance(Expression)
    name = Str("")
    # Variable store delegated to the parent frame.
    variables = DelegatesTo('parent')

    traits_view = View(Item(name='name'),
                       Item(name='parent',
                            label='Base',
                            editor=InstanceEditor(label="Frame")),
                       Item(name='T', label='Matrix4x4', style='custom'),
                       title='Frame properties')

    def evalT(self):
        """Return ``parent.evalT() * T`` for the current values, or None.

        Returns None when this frame's matrix has no current value or the
        parent's composed transform is None.
        """
        # `is None` rather than `!= None`: if the value is a numpy
        # array/matrix, `!=` compares element-wise and `if` on the result
        # is ambiguous.
        local = self.T.get_curr_value()
        if local is None:
            return None
        # Evaluate the parent chain once; calling parent.evalT() twice per
        # level doubles the work at every level of nesting.
        parent_T = self.parent.evalT()
        if parent_T is None:
            return None
        return parent_T * local

    def __init__(self, parent, T, name=""):
        # NOTE(review): HasTraits.__init__ is not called here; confirm that
        # skipping it is intentional.
        self.name = name
        self.parent = parent
        if isinstance(T, Expression):
            self.T = T
        else:
            # Wrap raw values through the (delegated) variables store.
            self.T = self.variables.new_expression(T)
Exemple #5
0
 def default_traits_view(self):
     """Build the default traits view of the Engine View.

     Two items editing ``engine`` sit side by side in a horizontal split:
     the pipeline tree (via ``self.tree_editor``) and the custom
     ``current_selection_view`` editor.  Dialog buttons are disabled and
     the window uses ``self.icon``, ``self.toolbar`` and
     ``EngineRichViewHandler``.
     """
     pipeline_item = Item('engine',
                          id='mayavi.engine_rich_view.pipeline_view',
                          springy=True,
                          resizable=True,
                          editor=self.tree_editor,
                          dock='tab',
                          label='Pipeline')
     selection_item = Item('engine',
                           id='mayavi.engine_rich_view.current_selection',
                           editor=InstanceEditor(view='current_selection_view'),
                           springy=True,
                           resizable=True,
                           style='custom')
     split = HSplit(pipeline_item,
                    selection_item,
                    show_labels=False,
                    id='mayavi.engine_rich_view_group')
     return View(split,
                 id='enthought.mayavi.engine_rich_view',
                 help=False,
                 resizable=True,
                 undo=False,
                 revert=False,
                 ok=False,
                 cancel=False,
                 title='Mayavi pipeline',
                 icon=self.icon,
                 toolbar=self.toolbar,
                 handler=EngineRichViewHandler)
Exemple #6
0
class PolyLine(Primitive):
    """Primitive that renders ``points`` as a single connected poly-line.

    Consecutive rows of ``points`` are joined by line segments
    (0-1, 1-2, ..., n-2 - n-1) in a ``tvtk.PolyData`` source.
    """
    source = Instance(tvtk.PolyData)
    # Vertex array; presumably shaped (N, 3) - TODO confirm with callers.
    points = Instance(numpy.ndarray)
    traits_view = View(Item(name='parent', label='Frame'),
                       Item(name='T', label='Matrix4x4', style='custom'),
                       Item(name='properties',
                            editor=InstanceEditor(),
                            label='Render properties'),
                       title='Line properties')

    def __init__(self, *args, **kwargs):
        Primitive.__init__(self, **kwargs)
        self.source = tvtk.PolyData()
        self.mapper = tvtk.PolyDataMapper(input=self.source)
        self.actor = tvtk.Actor(mapper=self.mapper)
        # `points` may already have been assigned (e.g. through keyword
        # arguments) before `source` existed; push it into the new source.
        if self.points is not None:
            self._points_changed(None, self.points)
        self.handle_arguments(*args, **kwargs)

    def _points_changed(self, old, new):
        """Copy the new vertices into ``source`` and rebuild the segment list.

        Does nothing when ``points`` is None, when ``source`` has not been
        created yet, or when there are fewer than two vertices.
        """
        if new is None or self.source is None:
            return
        npoints = len(new)
        if npoints < 2:
            return
        # Segment i joins vertex i to vertex i + 1.
        lines = np.zeros((npoints - 1, 2), dtype=int)
        lines[:, 0] = np.arange(0, npoints - 1)
        lines[:, 1] = np.arange(1, npoints)
        self.source.points = new
        self.source.lines = lines
Exemple #7
0
class PupilGenerator(HasTraits):
    """Dialog-editable parameters for generating a pupil function.

    A base pupil (chosen from ``pupils``) is generated on a grid of
    ``pupilSizeX`` samples with pixel size ``pixelSize``, for the given
    wavelength, numerical aperture ``NA`` and refractive index ``n``.
    """
    wavelength = Float(700)
    NA = Float(1.49)
    n = Float(1.51)
    pupilSizeX = Int(61)
    pixelSize = Float(70)

    # Selectable base pupils.
    pupils = List(_pupils)

    # One ZernikeMode per index 0..24.
    # NOTE(review): the ZernikeMode objects in this default list are created
    # once at class definition; if Traits only shallow-copies list defaults,
    # they are shared between instances - confirm.
    aberations = List(ZernikeMode, value=[ZernikeMode(i) for i in range(25)])

    basePupil = Instance(Pupil, _getDefaultPupil)

    pupil = None

    view = View(
        Item('basePupil',
             label='Pupil source',
             editor=InstanceEditor(name='pupils', editable=True)),
        Item('_'),
        Item('wavelength'),
        Item('n'),
        Item('NA'),
        Item('pupilSizeX'),
        Item('pixelSize'),
        Item('_'),
        Item('aberations'),
        buttons=[OKButton],
    )

    def GetPupil(self):
        """Generate the base pupil from the current parameters.

        NOTE(review): the four generated values are unpacked but neither
        stored (e.g. in ``self.pupil``) nor returned, and ``aberations`` is
        not applied; this looks truncated - confirm the intended behaviour.
        """
        generated = self.basePupil.GeneratePupil(
            self.pixelSize,
            self.pupilSizeX,
            self.wavelength,
            self.NA,
            self.n,
        )
        u, v, R, pupil = generated
Exemple #8
0
class RootPreferencesHelper(PreferencesHelper):
    """Top-level Mayavi preferences stored under ``enthought.mayavi``."""

    # The preferences path used for these preferences.
    preferences_path = 'enthought.mayavi'

    ######################################################################
    # Our preferences.

    # If True the user is prompted before a node on the tree is deleted;
    # if False nodes are deleted without confirmation.
    confirm_delete = Bool(desc='if the user is prompted before'
                          ' a node on the MayaVi tree is deleted')

    # Whether the splash screen is shown when mayavi starts.
    show_splash_screen = Bool(desc='if the splash screen is shown at'
                              ' startup')

    # Whether the helper (adder) nodes are shown on the mayavi tree view.
    show_helper_nodes = Bool(desc='if the helper (adder) nodes are shown'
                             ' on the tree view')

    # Whether help pages open in a chromeless browser window (Firefox only).
    open_help_in_light_browser = Bool(
        desc='if the help pages are opened in a chromeless'
        ' browser window (only works with Firefox)')

    # Contrib packages to load on startup.
    contrib_packages = List(Str, desc='contrib packages to load on startup')

    # Whether to use IPython for the embedded shell (when available).
    use_ipython = Bool(desc='use IPython for the embedded shell '
                       '(if available)')

    ########################################
    # Private traits.
    _contrib_finder = Instance(HasTraits)

    ######################################################################
    # Traits UI view.

    traits_view = View(
        Group(
            Item(name='confirm_delete'),
            Item(name='show_splash_screen'),
            Item(name='show_helper_nodes'),
            Item(name='open_help_in_light_browser'),
            Item('_contrib_finder',
                 show_label=False,
                 editor=InstanceEditor(label='Find contributions')),
        ),
        resizable=True,
    )

    ######################################################################
    # Non-public interface.
    ######################################################################
    def __contrib_finder_default(self):
        # Default for `_contrib_finder`; the import is done here so the
        # finder module is only loaded when the default is needed.
        from contrib_finder import ContribFinder
        finder = ContribFinder()
        return finder
Exemple #9
0
class Plane(Primitive):
    """Primitive rendering a ``tvtk.PlaneSource`` through a poly-data mapper."""
    source = Instance(tvtk.PlaneSource)
    traits_view = View(
        Item(name='parent', label='Frame'),
        Item(name='T', label='Matrix4x4', style='custom'),
        Item(name='properties',
             editor=InstanceEditor(),
             label='Render properties'),
        Item(name='source',
             editor=InstanceEditor(),
             label='Geometric properties'),
        title='Plane properties',
    )

    def __init__(self, *args, **kwargs):
        Primitive.__init__(self, **kwargs)
        # Pipeline: PlaneSource -> PolyDataMapper -> Actor.
        self.source = tvtk.PlaneSource()
        plane_output = self.source.output
        self.mapper = tvtk.PolyDataMapper(input=plane_output)
        self.actor = tvtk.Actor(mapper=self.mapper)
        self.handle_arguments(*args, **kwargs)
Exemple #10
0
class Image(Primitive):
    """Primitive displaying an image file read through ``tvtk.ImageReader``."""
    source = Instance(tvtk.ImageReader)
    # Delegated to the reader, so editing it changes the file being read.
    file_name = DelegatesTo('source')
    traits_view = View(Item(name='parent', label='Frame'),
                       Item(name='T', label='Matrix4x4', style='custom'),
                       Item(name='file_name'),
                       Item(name='source', editor=InstanceEditor()),
                       Item(name='actor', editor=InstanceEditor()),
                       Item(name='properties',
                            editor=InstanceEditor(),
                            label='Render properties'),
                       title='Image properties')

    def __init__(self, *args, **kwargs):
        # Forward keyword arguments as keywords, like the other primitives.
        # `*kwargs` would have passed only the dict's keys positionally.
        Primitive.__init__(self, **kwargs)
        # Default image file; presumably overridden through `file_name`.
        self.source = tvtk.ImageReader(file_name="woodpecker.bmp")
        self.source.set_data_scalar_type_to_unsigned_char()
        self.actor = tvtk.ImageActor(input=self.source.output)
        self.handle_arguments(*args, **kwargs)
class MyDemo(HasTraits):
    """Interactive viewer for tvtk parametric functions.

    ``func_name`` selects one of ``source_types``; the matching function
    object (looked up in ``sources``) drives ``source``, whose output is
    rendered in ``scene``.  Any trait change on the source or the current
    function re-renders the scene.
    """
    scene = Instance(SceneModel, ())

    source = Instance(tvtk.ParametricFunctionSource, ())

    func_name = Enum([c.__name__ for c in source_types])

    func = Property(depends_on="func_name")

    traits_view = View(
        HSplit(
            VGroup(
                Item("func_name", show_label=False),
                Tabbed(
                    Item("func",
                         style="custom",
                         editor=InstanceEditor(),
                         show_label=False),
                    Item("source", style="custom", show_label=False),
                ),
            ),
            Item("scene",
                 style="custom",
                 show_label=False,
                 editor=SceneEditor()),
        ),
        resizable=True,
        width=700,
        height=600,
    )

    def __init__(self, *args, **kwds):
        super(MyDemo, self).__init__(*args, **kwds)
        self._make_pipeline()

    def _get_func(self):
        # Function object registered under the currently selected name.
        return sources[self.func_name]

    def _make_pipeline(self):
        """Connect source -> mapper -> actor and add the actor to the scene."""
        func = self.func
        func.on_trait_change(self.on_change, "anytrait")
        src = self.source
        src.on_trait_change(self.on_change, "anytrait")
        src.parametric_function = func
        mapper = tvtk.PolyDataMapper(input_connection=src.output_port)
        actor = tvtk.Actor(mapper=mapper)
        self.scene.add_actor(actor)
        # Kept so _func_changed can swap the function on this source.
        self.src = src

    def _func_changed(self, old_func, this_func):
        # NOTE(review): `func` is an uncached Property; Traits may pass None
        # as the old value here, in which case the previous function's
        # listener is never removed - confirm.
        if old_func is not None:
            old_func.on_trait_change(self.on_change, "anytrait", remove=True)
        this_func.on_trait_change(self.on_change, "anytrait")
        self.src.parametric_function = this_func
        self.scene.render()

    def on_change(self):
        """Re-render the scene after any watched trait change."""
        self.scene.render()
Exemple #12
0
class Team(HasStrictTraits):
    """A named team with a roster of people and a captain picked from it."""

    name = Str
    # Edited with an InstanceEditor whose choices come from `roster`.
    captain = Instance(Person)
    roster = List(Person)

    traits_view = View(
        Item('name'),
        Item('_'),
        Item('captain',
             label='Team Captain',
             editor=InstanceEditor(name='roster', editable=True),
             style='custom'),
        buttons=['OK'],
    )
Exemple #13
0
class Text(Primitive):
    """Primitive rendering 3-D vector text from a ``tvtk.VectorText`` source."""
    # Delegated to the VectorText source.
    text = DelegatesTo('source')
    traits_view = View(
        Item(name='parent', label='Frame'),
        Item(name='T', label='Matrix4x4', style='custom'),
        Item(name='text'),
        Item(name='properties',
             editor=InstanceEditor(),
             label='Render properties'),
        title='Text properties',
    )

    def __init__(self, *args, **kwargs):
        Primitive.__init__(self, **kwargs)
        # Pipeline: VectorText -> PolyDataMapper -> Actor.
        self.source = tvtk.VectorText()
        text_output = self.source.get_output()
        self.mapper = tvtk.PolyDataMapper(input=text_output)
        self.actor = tvtk.Actor(mapper=self.mapper)
        self.handle_arguments(*args, **kwargs)
Exemple #14
0
class Arrow(Primitive):
    """Primitive rendering a ``tvtk.ArrowSource``."""
    source = Instance(tvtk.ArrowSource)
    # Delegated to the arrow source.
    tip_resolution = DelegatesTo("source")
    traits_view = View(
        Item(name='parent', label='Frame'),
        Item(name='T', label='Matrix4x4', style='custom'),
        Item(name='tip_resolution'),
        Item(name='properties',
             editor=InstanceEditor(),
             label='Render properties'),
        title='Arrow properties',
    )

    def __init__(self, *args, **kwargs):
        Primitive.__init__(self, **kwargs)
        # Pipeline: ArrowSource -> PolyDataMapper -> Actor.
        self.source = tvtk.ArrowSource()
        arrow_output = self.source.output
        self.mapper = tvtk.PolyDataMapper(input=arrow_output)
        self.actor = tvtk.Actor(mapper=self.mapper)
        self.handle_arguments(*args, **kwargs)
Exemple #15
0
class MappingMarchingCubes(TVTKMapperWidget):
    """Widget controlling a marching-cubes iso-value within [vmin, vmax]."""
    operator = Instance(tvtk.MarchingCubes)
    mapper = Instance(tvtk.HierarchicalPolyDataMapper)
    vmin = Float
    vmax = Float
    # When True the value editor updates live instead of on Enter.
    auto_set = Bool(False)
    # NOTE(review): this editor is a class attribute, so toggling auto_set
    # on one instance affects every instance's view.
    _val_redit = RangeEditor(format="%0.2f",
                             low_name='vmin',
                             high_name='vmax',
                             auto_set=False,
                             enter_set=True)
    traits_view = View(
        Item('value', editor=_val_redit),
        Item('auto_set'),
        Item('alpha',
             editor=RangeEditor(low=0.0,
                                high=1.0,
                                enter_set=True,
                                auto_set=False)),
        Item('lut_manager',
             show_label=False,
             editor=InstanceEditor(),
             style='custom'),
    )

    def __init__(self, vmin, vmax, vdefault, **traits):
        HasTraits.__init__(self, **traits)
        self.vmin = vmin
        self.vmax = vmax
        # `value` is added per instance so its Range bounds match vmin/vmax.
        self.add_trait("value",
                       Range(float(vmin), float(vmax), value=vdefault))
        self.value = vdefault

    def _auto_set_changed(self, old, new):
        live = new is True
        self._val_redit.auto_set = live
        self._val_redit.enter_set = not live

    def _value_changed(self, old, new):
        """Push the new iso-value into contour slot 0, then post-process."""
        self.operator.set_value(0, new)
        self.post_call()
Exemple #16
0
class MappingPlane(TVTKMapperWidget):
    """Widget moving a ``tvtk.Plane`` along ``self.axis`` within [vmin, vmax]."""
    plane = Instance(tvtk.Plane)
    # NOTE(review): the chained `editor =` also binds a class attribute named
    # `editor`; it looks accidental - confirm nothing relies on it.
    _coord_redit = editor = RangeEditor(format="%0.2e",
                                        low_name='vmin',
                                        high_name='vmax',
                                        auto_set=False,
                                        enter_set=True)
    # When True the coordinate editor updates live instead of on Enter.
    auto_set = Bool(False)
    traits_view = View(
        Item('coord', editor=_coord_redit),
        Item('auto_set'),
        Item('alpha',
             editor=RangeEditor(low=0.0,
                                high=1.0,
                                enter_set=True,
                                auto_set=False)),
        Item('lut_manager',
             show_label=False,
             editor=InstanceEditor(),
             style='custom'),
    )
    vmin = Float
    vmax = Float

    def _auto_set_changed(self, old, new):
        live = new is True
        self._coord_redit.auto_set = live
        self._coord_redit.enter_set = not live

    def __init__(self, vmin, vmax, vdefault, **traits):
        HasTraits.__init__(self, **traits)
        self.vmin = vmin
        self.vmax = vmax
        # `coord` is added per instance so its Range bounds match vmin/vmax.
        self.add_trait("coord",
                       Range(float(vmin), float(vmax), value=vdefault))
        self.coord = vdefault

    def _coord_changed(self, old, new):
        """Move the plane origin's ``axis`` component to ``new``, then post-process."""
        origin = self.plane.origin[:]
        origin[self.axis] = new
        self.plane.origin = origin
        self.post_call()
Exemple #17
0
class Cylinder(Primitive):
    """Primitive rendering a ``tvtk.CylinderSource``."""
    source = Instance(tvtk.CylinderSource)
    # Geometry traits delegated to the cylinder source.
    height = DelegatesTo('source')
    radius = DelegatesTo('source')
    resolution = DelegatesTo('source')
    traits_view = View(Item(name='parent', label='Frame'),
                       Item(name='T', label='Matrix4x4', style='custom'),
                       Item(name='height'),
                       Item(name='radius'),
                       Item(name='resolution'),
                       Item(name='properties',
                            editor=InstanceEditor(),
                            label='Render properties'),
                       title='Cylinder properties')

    def __init__(self, *args, **kwargs):
        # Forward keyword arguments as keywords, like the other primitives.
        # `*kwargs` would have passed only the dict's keys positionally.
        Primitive.__init__(self, **kwargs)
        self.source = tvtk.CylinderSource()
        self.mapper = tvtk.PolyDataMapper(input=self.source.output)
        self.actor = tvtk.Actor(mapper=self.mapper)
        self.handle_arguments(*args, **kwargs)
Exemple #18
0
class Axes(Primitive):
    """Primitive rendering symmetric coordinate axes thickened by a tube filter."""
    source = Instance(tvtk.Axes)
    tube = Instance(tvtk.TubeFilter)

    # Tube geometry, delegated to (or prototyped from) the tube filter.
    scale_factor = DelegatesTo('tube')
    radius = DelegatesTo('tube')
    sides = PrototypedFrom('tube', 'number_of_sides')

    traits_view = View(Item(name='parent', label='Frame'),
                       Item(name='T', label='Matrix4x4', style='custom'),
                       Item(name='properties',
                            editor=InstanceEditor(),
                            label='Render properties'),
                       title='Axes properties')

    def __init__(self, *args, **kwargs):
        # Forward keyword arguments as keywords, like the other primitives.
        # `*kwargs` would have passed only the dict's keys positionally.
        Primitive.__init__(self, **kwargs)
        # Pipeline: Axes -> TubeFilter -> PolyDataMapper -> Actor.
        self.source = tvtk.Axes(symmetric=1)
        self.tube = tvtk.TubeFilter(vary_radius='vary_radius_off',
                                    input=self.source.output)
        self.mapper = tvtk.PolyDataMapper(input=self.tube.output)
        self.actor = tvtk.Actor(mapper=self.mapper)
        self.handle_arguments(*args, **kwargs)
Exemple #19
0
class Circle(PolyLine):
    """Poly-line approximating a circle of radius ``radius`` in the local z=0 plane."""
    radius = Instance(Expression)
    # Number of vertices used to approximate the circle.
    resolution = Int(100)
    traits_view = View(Item(name='parent', label='Frame'),
                       Item(name='T', label='Matrix4x4', style='custom'),
                       Item(name='radius', style='custom'),
                       Item(name='resolution'),
                       Item(name='properties',
                            editor=InstanceEditor(),
                            label='Render properties'),
                       title='Line properties')

    def __init__(self, *args, **kwargs):
        PolyLine.__init__(self, *args, **kwargs)

    def update(self):
        """Recompute the circle's vertices from the current radius and redraw.

        Does nothing (and skips the parent update) while the radius
        expression has no current value.
        """
        import math

        r = self.radius.get_curr_value()
        if r is None:
            return
        # Sweep exactly one turn so the last vertex lands on the first;
        # a hard-coded 6.29 would overshoot 2*pi and overlap the start.
        t = linspace(0, 2 * math.pi, self.resolution)
        self.points = array([r * sin(t), r * cos(t), zeros(t.shape)]).T
        super(Circle, self).update()
Exemple #20
0
class Box(Primitive):
    """Primitive rendering an axis-aligned box from a ``tvtk.CubeSource``."""
    source = Instance(tvtk.CubeSource)
    # Edge lengths delegated to the cube source.
    x_length = DelegatesTo('source')
    y_length = DelegatesTo('source')
    z_length = DelegatesTo('source')

    traits_view = View(
        Item(name='parent', label='Frame'),
        Item(name='T', label='Matrix4x4', style='custom'),
        Item(name='x_length'),
        Item(name='y_length'),
        Item(name='z_length'),
        Item(name='properties',
             editor=InstanceEditor(),
             label='Render properties'),
        title='Box properties',
    )

    def __init__(self, *args, **kwargs):
        Primitive.__init__(self, **kwargs)
        self.source = tvtk.CubeSource()
        # NOTE(review): unlike the sibling primitives the mapper is stored as
        # `polyDataMapper`, not `mapper` - confirm nothing expects `mapper`.
        cube_mapper = tvtk.PolyDataMapper()
        self.polyDataMapper = cube_mapper
        cube_mapper.input = self.source.output

        self.actor = tvtk.Actor(mapper=cube_mapper)
        self.handle_arguments(*args, **kwargs)
Exemple #21
0
class ScalarCutPlane(Module):
    """Mayavi module that cuts its input with an interactive implicit plane.

    The cut can optionally be scalar-warped, smoothed with computed normals
    and contoured before it reaches the actor.  The handlers below rewire
    the component inputs to form this chain:

        implicit_plane -> cutter -> [warp_scalar -> [normals]] -> [contour] -> actor
    """

    # The version of this class.  Used for persistence.
    __version__ = 0
    
    # The implicit plane widget used to place the implicit function.
    implicit_plane = Instance(ImplicitPlane, allow_none=False,
                              record=True)
    
    # The cutter.  Takes a cut of the data on the implicit plane.
    cutter = Instance(Cutter, allow_none=False, record=True)
    
    # Specifies if contouring is to be done or not.
    enable_contours = Bool(False, desc='if contours are generated')
    
    # The Contour component that contours the data.
    contour = Instance(Contour, allow_none=False, record=True)
    
    # Specifies if scalar warping is to be done or not.
    enable_warp_scalar = Bool(False, desc='if scalar warping is enabled')
    
    # The WarpScalar component that warps the cut data.
    warp_scalar = Instance(WarpScalar, allow_none=False, record=True)
    
    # Specify if scalar normals are to be computed to make a smoother surface.
    compute_normals = Bool(False, desc='if normals are to be computed '\
                           'to make the warped scalar surface smoother')
    
    # The component that computes the scalar normals.
    normals = Instance(PolyDataNormals, allow_none=False, record=True)
    
    # The actor component that represents the visualization.
    actor = Instance(Actor, allow_none=False, record=True)

    # Accepts any dataset type carrying scalar attributes.
    input_info = PipelineInfo(datasets=['any'],
                              attribute_types=['any'],
                              attributes=['scalars'])    
    
    ########################################
    # View related code.

    # Sub-view that exposes only the warp filter's `scale_factor`.
    _warp_group = Group(Item(name='filter',
                             style='custom',
                             editor=\
                             InstanceEditor(view=
                                            View(Item('scale_factor')))),
                        show_labels=False)                               
    
    # One group per component: plane, contours, warp/normals, actor.
    view = View(Group(Item(name='implicit_plane',
                           style='custom'),
                      label='ImplicitPlane',
                      show_labels=False),
                Group(Group(Item(name='enable_contours')),
                      Group(Item(name='contour',
                                 style='custom',
                                 enabled_when='object.enable_contours'),
                            show_labels=False),
                      label='Contours',
                      show_labels=False),
                Group(Item(name='enable_warp_scalar'),
                      Group(Item(name='warp_scalar',
                                 enabled_when='enable_warp_scalar',
                                 style='custom',
                                 editor=InstanceEditor(view=
                                                       View(_warp_group))
                                 ),
                            show_labels=False,
                            ),
                      Item(name='_'),
                      Item(name='compute_normals',
                           enabled_when='enable_warp_scalar'),
                      Item(name='normals',
                           style='custom',
                           show_label=False,
                           enabled_when='compute_normals and enable_warp_scalar'),
                      label='WarpScalar',
                      show_labels=True),
                Group(Item(name='actor',
                           style='custom'),
                      label='Actor',
                      show_labels=False)
                )
    
    ######################################################################
    # `Module` interface
    ######################################################################
    def setup_pipeline(self):
        """Override this method so that it *creates* the tvtk
        pipeline.

        This method is invoked when the object is initialized via
        `__init__`.  Note that at the time this method is called, the
        tvtk data pipeline will *not* yet be setup.  So upstream data
        will not be available.  The idea is that you simply create the
        basic objects and setup those parts of the pipeline not
        dependent on upstream sources and filters.  You should also
        set the `actors` attribute up at this point.
        """
        # Create the objects.  Each assignment fires the matching
        # `_<name>_changed` handler below, which wires the component in.
        self.implicit_plane = ImplicitPlane()
        self.cutter = Cutter()        
        self.contour = Contour(auto_contours=True, number_of_contours=10)
        self.warp_scalar = WarpScalar()
        self.normals = PolyDataNormals()
        self.actor = Actor()
        
        # Setup the actor suitably for this module.
        prop = self.actor.property
        prop.line_width = 2.0
        
    def update_pipeline(self):
        """Override this method so that it *updates* the tvtk pipeline
        when data upstream is known to have changed.

        This method is invoked (automatically) when any of the inputs
        sends a `pipeline_changed` event.
        """
        mm = self.module_manager
        if mm is None:
            return
        
        # Data is available, so set the input for the implicit plane.
        self.implicit_plane.inputs = [mm.source]
        
        # Ensure that the warped scalar surface's normal is setup right.
        self.warp_scalar.filter.normal = self.implicit_plane.normal
        
        # This makes sure that any changes made to enable_warp when
        # the module is not running are updated when it is started --
        # this in turn calls the other functions (normals and
        # contours) internally.
        self._enable_warp_scalar_changed(self.enable_warp_scalar)
        
        # Set the LUT for the mapper.
        self.actor.set_lut(mm.scalar_lut_manager.lut)
        
        self.pipeline_changed = True
        
    def update_data(self):
        """Override this method so that it flushes the vtk pipeline if
        that is necessary.

        This method is invoked (automatically) when any of the inputs
        sends a `data_changed` event.
        """
        # Just set data_changed, the components should do the rest if
        # they are connected.
        self.data_changed = True

    ######################################################################
    # Non-public methods.
    ######################################################################
    def _get_warp_output(self):
        """Helper function to return the warped (or not) output
        depending on settings.

        Returns `normals` when warping and normals are both enabled,
        `warp_scalar` when only warping is enabled, else `cutter`.
        """
        if self.enable_warp_scalar:
            if self.compute_normals:
                return self.normals
            else:
                return self.warp_scalar
        else:
            return self.cutter

    def _get_contour_output(self):
        """Helper function to return the contoured (and warped (or
        not)) output depending on settings.

        Returns `contour` when contouring is enabled, otherwise the
        result of `_get_warp_output()`.
        """
        if self.enable_contours:
            return self.contour
        else:
            return self._get_warp_output()
        
    def _filled_contours_changed_for_contour(self, value):
        """When filled contours are enabled, the mapper should use the
        the cell data, otherwise it should use the default scalar
        mode.

        By its name this is presumably the Traits handler for
        `filled_contours` changing on `self.contour`.
        """
        if value:
            self.actor.mapper.scalar_mode = 'use_cell_data'
        else:
            self.actor.mapper.scalar_mode = 'default'            
        self.render()
    
    def _enable_warp_scalar_changed(self, value):
        """Turns on and off the scalar warping.

        Connects (or disconnects) the warp filter to the cutter, then
        re-evaluates the normals stage, which in turn re-evaluates the
        contour stage.  No-op until a module manager is attached.
        """
        if self.module_manager is None:
            return

        if value:
            self.warp_scalar.inputs = [self.cutter]
        else:
            self.warp_scalar.inputs = []
        self._compute_normals_changed(self.compute_normals)
        self.render()
    
    def _compute_normals_changed(self, value):
        """Connect/disconnect the normals filter after the warp filter.

        Normals inputs are only touched while warping is enabled; the
        contour stage is re-evaluated afterwards in every case.
        """
        if self.module_manager is None:
            return

        if self.enable_warp_scalar:
            normals = self.normals
            if value:
                normals.inputs = [self.warp_scalar]
            else:
                normals.inputs = []
        self._enable_contours_changed(self.enable_contours)
        self.render()
    
    def _enable_contours_changed(self, value):
        """Turns on and off the contours.

        Feeds the actor from the contour component (when enabled) or
        directly from the warp/cutter output, and sets the mapper's scalar
        mode to match.
        """
        if self.module_manager is None:
            return

        actor = self.actor
        if value:
            self.contour.inputs = [self._get_warp_output()]
            actor.inputs = [self._get_contour_output()]
            # Filled contours carry their scalars on the cells.
            if self.contour.filled_contours:
                actor.mapper.scalar_mode = 'use_cell_data'
        else:
            self.contour.inputs = []
            actor.inputs = [self._get_warp_output()]
            actor.mapper.scalar_mode = 'default'
        self.render()
    
    def _normals_changed(self, old, new):
        """Wire a replacement normals component in after the warp filter."""
        warp_scalar = self.warp_scalar
        if warp_scalar is not None:
            new.inputs = [warp_scalar]
            self._compute_normals_changed(self.compute_normals)
        self._change_components(old, new)
    
    def _implicit_plane_changed(self, old, new):
        """Wire a replacement implicit plane into the cutter and track its normal."""
        cutter = self.cutter
        if cutter is not None:
            cutter.cut_function = new.plane
            cutter.inputs = [new]
            # Update the pipeline.
            self._enable_warp_scalar_changed(self.enable_warp_scalar)
        # Hook up events to set the normals of the warp filter.
        if old is not None:
            old.widget.on_trait_change(self._update_normal, 'normal', remove=True)
        new.widget.on_trait_change(self._update_normal, 'normal')
        self._change_components(old, new)
    
    def _cutter_changed(self, old, new):
        """Connect a replacement cutter to the current implicit plane."""
        ip = self.implicit_plane
        if ip is not None:
            new.cut_function = ip.plane
            new.inputs = [ip]
            # Update the pipeline.
            self._enable_warp_scalar_changed(self.enable_warp_scalar)
        self._change_components(old, new)
        
    def _contour_changed(self, old, new):
        """Re-evaluate the contour stage for a replacement contour component."""
        # Update the pipeline.
        self._enable_contours_changed(self.enable_contours)
        self._change_components(old, new)
    
    def _warp_scalar_changed(self, old, new):
        """Re-evaluate the warp stage for a replacement warp component."""
        # Update the pipeline.
        self._enable_warp_scalar_changed(self.enable_warp_scalar)
        self._change_components(old, new)
    
    def _actor_changed(self, old, new):
        """Re-evaluate the actor's inputs for a replacement actor component."""
        # Update the pipeline.
        self._enable_contours_changed(self.enable_contours)
        self._change_components(old, new)
    
    def _update_normal(self):
        """Invoked when the orientation of the implicit plane changes.

        Keeps the warp filter's normal equal to the plane widget's normal.
        """
        ws = self.warp_scalar
        if ws is not None:
            ws.filter.normal = self.implicit_plane.widget.normal
Exemple #22
0
class tcWindow(HasTraits):
    """Top-level PyTimechart window.

    Shows the project (through its ``process_view``) next to the timechart
    plot, with a toolbar, a menu bar and a status bar that reports the
    current selection time.
    """
    project = tcProject
    plot = tcPlot

    def __init__(self, project):
        """Build the window for *project*.

        Creates the timechart container, listens for ``time`` changes on
        its range tools and sets the view title from `get_title`.
        """
        self.project = project
        self.plot = create_timechart_container(project)
        self.plot_range_tools = self.plot.range_tools
        self.plot_range_tools.on_trait_change(self._selection_time_changed,
                                              "time")
        # NOTE(review): trait_view() presumably returns the class-level
        # `traits_view`, so this title would be shared by all instances.
        self.trait_view().title = self.get_title()

    def get_title(self):
        """Return the window title for the current project file."""
        if self.project.filename == "dummy":
            return "PyTimechart: Please Open a File"
        return "PyTimechart:" + self.project.filename

    # Text displayed in the window's status bar.
    status = Str("Welcome to PyTimechart")
    traits_view = View(
        HSplit(
            VSplit(
                Item('project',
                     show_label=False,
                     editor=InstanceEditor(view='process_view'),
                     style='custom',
                     width=150),
                #                Item('plot_range_tools', show_label = False, editor=InstanceEditor(view = 'selection_view'), style='custom',width=150,height=100)
            ),
            Item('plot', show_label=False, editor=ComponentEditor()),
        ),
        toolbar=ToolBar(*_create_toolbar_actions(),
                        image_size=(24, 24),
                        show_tool_names=False),
        menubar=MenuBar(*_create_menubar_actions()),
        statusbar=[
            StatusItem(name='status'),
        ],
        resizable=True,
        width=1280,
        height=1024,
        handler=tcActionHandler())

    def _on_open_trace_file(self):
        """Open a trace file; dispose this window's UI if it showed the 'dummy' project."""
        # NOTE(review): `_ui` is not assigned anywhere in this class;
        # presumably it is set elsewhere (e.g. by the UI handler) — confirm.
        if open_file(None) and self.project.filename == "dummy":
            self._ui.dispose()

    def _on_view_properties(self):
        """Open an editor for the plot options."""
        self.plot.options.edit_traits()

    def _on_exit(self, n=None):
        """Close the window and terminate the process with status 0."""
        self.close()
        sys.exit(0)

    def close(self, n=None):
        """Hook called by `_on_exit` before exiting; currently does nothing."""
        pass

    def _on_about(self):
        """Show the About box."""
        aboutBox().edit_traits()

    def _on_doc(self):
        """Open the documentation via `browse_doc`."""
        browse_doc()

    def _selection_time_changed(self):
        """Report the range tools' selection time in the status bar."""
        # Pass a 1-tuple so %-formatting treats the value as a single
        # argument even if it happens to be a tuple itself.
        self.status = "selection time:%s" % (self.plot_range_tools.time,)
class WarpVectorCutPlane(Module):
    """Cut the data with an implicit plane and warp the cut by its vectors.

    The tvtk pipeline wired up by the handlers below is::

        source -> implicit_plane -> cutter -> warp_vector
               -> [normals, when compute_normals] -> actor
    """

    # The version of this class.  Used for persistence.
    __version__ = 0

    # The implicit plane widget used to place the implicit function.
    implicit_plane = Instance(ImplicitPlane, allow_none=False,
                              record=True)

    # The cutter.  Takes a cut of the data on the implicit plane.
    cutter = Instance(Cutter, allow_none=False, record=True)

    # The WarpVector component that warps the data.
    warp_vector = Instance(WarpVector, allow_none=False, record=True)

    # Specify if vector normals are to be computed to make a smoother surface.
    compute_normals = Bool(False, desc='if normals are to be computed '
                           'to make the warped surface smoother')

    # The component that computes the normals.
    normals = Instance(PolyDataNormals, record=True)

    # The Actor component.
    actor = Instance(Actor, allow_none=False, record=True)

    input_info = PipelineInfo(datasets=['any'],
                              attribute_types=['any'],
                              attributes=['vectors'])

    ########################################
    # View related traits.

    # Sub-view for the warp filter: only its scale factor is exposed.
    _warp_group = Group(Item(name='filter',
                             style='custom',
                             editor=InstanceEditor(
                                 view=View(Item('scale_factor')))),
                        show_labels=False)

    view = View(Group(Item(name='implicit_plane', style='custom'),
                      label='ImplicitPlane',
                      show_labels=False),
                Group(Group(Item(name='warp_vector',
                                 style='custom',
                                 resizable=True,
                                 show_label=False,
                                 editor=InstanceEditor(view=View(_warp_group))
                                 ),
                            ),
                      Item(name='_'),
                      Item(name='compute_normals'),
                      Group(Item(name='normals',
                                 style='custom',
                                 show_label=False,
                                 enabled_when='compute_normals'),
                            ),
                      label='WarpVector',
                      show_labels=True),
                Group(Item(name='actor', style='custom'),
                      label='Actor',
                      show_labels=False),
                resizable=True,
                )

    ######################################################################
    # `Module` interface
    ######################################################################
    def setup_pipeline(self):
        """Override this method so that it *creates* the tvtk
        pipeline.

        This method is invoked when the object is initialized via
        `__init__`.  Note that at the time this method is called, the
        tvtk data pipeline will *not* yet be setup.  So upstream data
        will not be available.  The idea is that you simply create the
        basic objects and setup those parts of the pipeline not
        dependent on upstream sources and filters.  You should also
        set the `actors` attribute up at this point.
        """
        # Create the objects and set them up.  The assignment order matters:
        # each `_*_changed` handler connects to components created earlier.
        self.implicit_plane = ImplicitPlane()
        self.cutter = Cutter()
        self.warp_vector = WarpVector()
        self.normals = PolyDataNormals()
        actor = self.actor = Actor()
        actor.mapper.scalar_visibility = 1

    def update_pipeline(self):
        """Override this method so that it *updates* the tvtk pipeline
        when data upstream is known to have changed.

        This method is invoked (automatically) when any of the inputs
        sends a `pipeline_changed` event.
        """
        mm = self.module_manager
        if mm is None:
            return

        self.implicit_plane.inputs = [mm.source]

        # Force the vector normals setting to be noted.
        self._compute_normals_changed(self.compute_normals)

        # Set the LUT for the mapper.
        self.actor.set_lut(mm.scalar_lut_manager.lut)

        self.pipeline_changed = True

    def update_data(self):
        """Override this method so that it flushes the vtk pipeline if
        that is necessary.

        This method is invoked (automatically) when any of the inputs
        sends a `data_changed` event.
        """
        # Just set data_changed, the other components should do the rest.
        self.data_changed = True

    ######################################################################
    # Non-public traits.
    ######################################################################
    def _compute_normals_changed(self, value):
        """Feed the actor from the normals filter when *value* is true,
        otherwise straight from the warp filter, then re-render.

        Does nothing until the module is attached to a module manager.
        """
        if self.module_manager is None:
            return
        actor = self.actor
        if actor is not None:
            if value:
                actor.inputs = [self.normals]
            else:
                actor.inputs = [self.warp_vector]
        self.render()

    def _normals_changed(self, old, new):
        """Attach a newly assigned normals filter to the warp filter."""
        warp_vector = self.warp_vector
        # The input must be guarded on the warp filter's existence;
        # `compute_normals` is a Bool and is never None.
        if warp_vector is not None:
            new.inputs = [warp_vector]
        self._compute_normals_changed(self.compute_normals)
        self._change_components(old, new)

    def _implicit_plane_changed(self, old, new):
        """Make a newly assigned implicit plane drive the cutter."""
        cutter = self.cutter
        if cutter is not None:
            cutter.cut_function = new.plane
            cutter.inputs = [new]
        self._change_components(old, new)

    def _warp_vector_changed(self, old, new):
        """Splice a newly assigned warp filter between cutter and normals."""
        cutter = self.cutter
        if cutter is not None:
            new.inputs = [cutter]
        # Re-feed the normals filter from the replacement warp filter so it
        # does not keep reading from the old one.
        normals = self.normals
        if normals is not None:
            normals.inputs = [new]
        self._compute_normals_changed(self.compute_normals)
        self._change_components(old, new)

    def _cutter_changed(self, old, new):
        """Splice a newly assigned cutter between the plane and warp filter."""
        ip = self.implicit_plane
        if ip is not None:
            new.cut_function = ip.plane
            new.inputs = [ip]
        w = self.warp_vector
        if w is not None:
            w.inputs = [new]
        self._change_components(old, new)

    def _actor_changed(self, old, new):
        """Reconnect a newly assigned actor according to `compute_normals`."""
        self._compute_normals_changed(self.compute_normals)
        self._change_components(old, new)
Exemple #24
0
class Scene(TVTKScene, Widget):
    """A VTK interactor scene widget for pyface and PyQt.

    This widget uses a RenderWindowInteractor and therefore supports
    interaction with VTK widgets.  The widget uses TVTK.  In addition
    to the features that the base TVTKScene provides this widget
    supports:

    - saving the rendered scene to the clipboard.

    - picking data on screen.  Press 'p' or 'P' when the mouse is over
      a point that you need to pick.

    - The widget also uses a light manager to manage the lighting of
      the scene.  Press 'l' or 'L' to activate a GUI configuration
      dialog for the lights.

    - Pressing the left, right, up and down arrow let you rotate the
      camera in those directions.  When shift-arrow is pressed then
      the camera is panned.  Pressing the '+' (or '=')  and '-' keys
      let you zoom in and out.

    - full screen rendering via the full_screen button on the UI.

    """

    # The version of this class.  Used for persistence.
    __version__ = 0

    ###########################################################################
    # Traits.
    ###########################################################################

    # Turn on full-screen rendering.
    full_screen = Button('Full Screen')

    # The picker handles pick events.
    picker = Instance(picker.Picker)

    ########################################

    # Render_window's view.
    _stereo_view = Group(Item(name='stereo_render'),
                         Item(name='stereo_type'),
                         show_border=True,
                         label='Stereo rendering',
                         )

    # The default view of this object.
    default_view = View(Group(
                            Group(Item(name='background'),
                                  Item(name='foreground'),
                                  Item(name='parallel_projection'),
                                  Item(name='disable_render'),
                                  Item(name='off_screen_rendering'),
                                  Item(name='jpeg_quality'),
                                  Item(name='jpeg_progressive'),
                                  Item(name='magnification'),
                                  Item(name='anti_aliasing_frames'),
                                  Item(name='full_screen',
                                       show_label=False),
                                  ),
                            Group(Item(name='render_window',
                                       style='custom',
                                       visible_when='object.stereo',
                                       editor=InstanceEditor(view=View(_stereo_view)),
                                       show_label=False),
                                  ),
                            label='Scene'),
                         Group( Item(name='light_manager',
                                style='custom', show_label=False),
                                label='Lights'),
                         buttons=['OK', 'Cancel']
                        )

    ########################################
    # Private traits.

    # The render-window interactor control created in `_create_control`.
    _vtk_control = Instance(_VTKRenderWindowInteractor)
    # The FullScreen helper while full-screen rendering runs, else None.
    _fullscreen = Any

    ###########################################################################
    # 'object' interface.
    ###########################################################################
    def __init__(self, parent=None, **traits):
        """ Initializes the object. """

        # Base class constructor.
        super(Scene, self).__init__(parent, **traits)

        # Setup the default picker.
        self.picker = picker.Picker(self)

        # The light manager needs creating.
        self.light_manager = None

        # Cursor restored by `show_cursor`.
        self._cursor = QtCore.Qt.ArrowCursor

    def __get_pure_state__(self):
        """Allows us to pickle the scene."""
        # The control attribute is not picklable since it is a VTK
        # object so we remove it.
        d = super(Scene, self).__get_pure_state__()
        for x in ['_vtk_control', '_fullscreen']:
            d.pop(x, None)
        return d

    ###########################################################################
    # 'Scene' interface.
    ###########################################################################
    def render(self):
        """ Force the scene to be rendered. Nothing is done if the
        `disable_render` trait is set to True."""
        if not self.disable_render:
            self._vtk_control.Render()

    def get_size(self):
        """Return size of the render window."""
        sz = self._vtk_control.size()

        return (sz.width(), sz.height())

    def set_size(self, size):
        """Set the size of the window."""
        self._vtk_control.resize(*size)

    def hide_cursor(self):
        """Hide the cursor."""
        # Remember the current cursor shape so show_cursor can restore it.
        self._cursor = self._vtk_control.cursor().shape()
        self._vtk_control.setCursor(QtCore.Qt.BlankCursor)

    def show_cursor(self):
        """Show the cursor."""
        self._vtk_control.setCursor(self._cursor)

    ###########################################################################
    # 'TVTKScene' interface.
    ###########################################################################
    def save_to_clipboard(self):
        """Saves a bitmap of the scene to the clipboard."""
        handle, name = tempfile.mkstemp()
        # Only the path is needed: close our descriptor immediately so it
        # cannot leak, and always remove the temporary file afterwards.
        os.close(handle)
        try:
            self.save_bmp(name)
            QtGui.QApplication.clipboard().setImage(QtGui.QImage(name))
        finally:
            os.unlink(name)

    ###########################################################################
    # Non-public interface.
    ###########################################################################
    def _create_control(self, parent):
        """ Create the toolkit-specific control that represents the widget. """

        # Create the VTK widget.
        self._vtk_control = window = _VTKRenderWindowInteractor(self, parent,
                                                                 stereo=self.stereo)

        # Switch the default interaction style to the trackball one.
        window.GetInteractorStyle().SetCurrentStyleToTrackballCamera()

        # Grab the renderwindow.
        renwin = self._renwin = tvtk.to_tvtk(window.GetRenderWindow())
        renwin.set(point_smoothing=self.point_smoothing,
                   line_smoothing=self.line_smoothing,
                   polygon_smoothing=self.polygon_smoothing)
        # Create a renderer and add it to the renderwindow
        self._renderer = tvtk.Renderer()
        renwin.add_renderer(self._renderer)
        # Save a reference to our camera so it is not GC'd -- needed for
        # the sync_traits to work.
        self._camera = self.camera

        # Sync various traits.
        self._renderer.background = self.background
        self.sync_trait('background', self._renderer)
        self.renderer.on_trait_change(self.render, 'background')
        renwin.off_screen_rendering = self.off_screen_rendering
        self._camera.parallel_projection = self.parallel_projection
        self.sync_trait('parallel_projection', self._camera)
        self.sync_trait('off_screen_rendering', self._renwin)
        self.render_window.on_trait_change(self.render, 'off_screen_rendering')
        self.render_window.on_trait_change(self.render, 'stereo_render')
        self.render_window.on_trait_change(self.render, 'stereo_type')
        self.camera.on_trait_change(self.render, 'parallel_projection')

        self._interactor = tvtk.to_tvtk(window._Iren)

        return window

    def _lift(self):
        """Lift the window to the top. Useful when saving screen to an
        image."""
        if self.render_window.off_screen_rendering:
            # Do nothing if off screen rendering is being used.
            return

        self._vtk_control.window().raise_()
        QtCore.QCoreApplication.processEvents()

    def _full_screen_fired(self):
        """Render the scene full screen unless that is already running."""
        if self._fullscreen is not None:
            return
        f = FullScreen(self)
        # Record the active session so the guard above actually works while
        # it runs, and clear it however run() exits.
        self._fullscreen = f
        try:
            f.run() # This will block.
        finally:
            self._fullscreen = None

    def _busy_changed(self, val):
        """Mirror the `busy` trait onto the GUI busy state."""
        GUI.set_busy(val)
Exemple #25
0
 
#-------------------------------------------------------------------------------
#  Define the View to use:
#-------------------------------------------------------------------------------

view = View(
    Group(
        [ Item( 'company', 
                editor    = tree_editor, 
                resizable = True ),
          '|<>' ],
        Group( 
            [ '{Employee of the Month}@',
              Item( 'eom@', 
                    editor = InstanceEditor( values = [ 
                                 InstanceDropChoice( klass      = Employee,
                                                     selectable = True ) ] ),
                    resizable = True ),
              '|<>' ],
            [ '{Department of the Month}@',
              Item( 'dom@', 
                    editor = InstanceEditor( values = [ 
                                 InstanceDropChoice( klass = Department ) ] ),
                    resizable = True ),
              '|<>' ],
            show_labels = False,
            layout      = 'split' ),
        orientation = 'horizontal',
        show_labels = False,
        layout      = 'split' ),
    title     = 'Company Structure',
Exemple #26
0
class DataSourceWizardView(DataSourceWizard):
    """TraitsUI front-end for `DataSourceWizard`.

    Shows the available arrays next to a questionnaire whose sub-view is
    chosen from the data type, offers a structure preview, and builds the
    data source when OK is pressed.
    """

    #----------------------------------------------------------------------
    # Private traits
    #----------------------------------------------------------------------

    _top_label = Str('Describe your data')

    _info_text = Str('Array size do not match')

    _array_label = Str('Available arrays')

    _data_type_text = Str("What does your data represents?" )

    _lines_text = Str("Connect the points with lines" )

    _scalar_data_text = Str("Array giving the value of the scalars")

    _optional_scalar_data_text = Str("Associate scalars with the data points")

    _connectivity_text = Str("Array giving the triangles")

    _vector_data_text = Str("Associate vector components")

    # Caption above the coordinate selectors; depends on the position type.
    _position_text = Property(depends_on="position_type_")

    _position_text_dict = {
        'explicit': 'Coordinnates of the data points:',
        'orthogonal grid': 'Position of the layers along each axis:',
    }

    def _get__position_text(self):
        return self._position_text_dict.get(self.position_type_, "")

    _shown_help_text = Str

    # One wrapper (name + shape) per available array, for the array table.
    _data_sources_wrappers = Property(depends_on='data_sources')

    def _get__data_sources_wrappers(self):
        return [
            ArrayColumnWrapper(name=name,
                               shape=repr(self.data_sources[name].shape))
            for name in self._data_sources_names
        ]

    # A traits pointing to the object, to play well with traitsUI
    _self = Instance(DataSourceWizard)

    # Name of the sub-view matching the data type, e.g. '_point_data_view'.
    _suitable_traits_view = Property(depends_on="data_type_")

    def _get__suitable_traits_view(self):
        return "_%s_data_view" % self.data_type_

    # The UI returned by `edit_traits` in `view_wizard`; False before that.
    ui = Any(False)

    _preview_button = Button(label='Preview structure')

    def __preview_button_fired(self):
        if self.ui:
            self.build_data_source()
            self.preview()

    _ok_button = Button(label='OK')

    def __ok_button_fired(self):
        if self.ui:
            self.ui.dispose()
            self.build_data_source()

    _cancel_button = Button(label='Cancel')

    def __cancel_button_fired(self):
        if self.ui:
            self.ui.dispose()

    _is_ok = Bool

    _is_not_ok = Bool

    def _anytrait_changed(self):
        """ Validates if the OK button is enabled.
        """
        if self.ui:
            self._is_ok =  self.check_arrays()
            self._is_not_ok = not self._is_ok

    _preview_window = Instance(PreviewWindow, ())

    _info_image = Instance(ImageResource,
                    ImageLibrary.image_resource('@std:alert16',))

    #----------------------------------------------------------------------
    # TraitsUI views
    #----------------------------------------------------------------------

    # NOTE: the groups below are plain Group objects (no trailing comma), so
    # they are embedded directly rather than as 1-tuples.

    _coordinates_group = HGroup(
        Item('position_x', label='x',
             editor=EnumEditor(name='_data_sources_names',
                               invalid='_is_not_ok')),
        Item('position_y', label='y',
             editor=EnumEditor(name='_data_sources_names',
                               invalid='_is_not_ok')),
        Item('position_z', label='z',
             editor=EnumEditor(name='_data_sources_names',
                               invalid='_is_not_ok')),
    )

    _position_group = Group(
        Item('position_type'),
        Group(
            Item('_position_text', style='readonly',
                 resizable=False,
                 show_label=False),
            _coordinates_group,
            visible_when='not position_type_=="image data"',
        ),
        Group(
            Item('grid_shape_source_',
                 label='Grid shape',
                 editor=EnumEditor(name='_grid_shape_source_labels',
                                   invalid='_is_not_ok')),
            HGroup(
                spring,
                Item('grid_shape', style='custom',
                     editor=ArrayEditor(width=-60),
                     show_label=False),
                enabled_when='grid_shape_source==""',
            ),
            visible_when='position_type_=="image data"',
        ),
        label='Position of the data points',
        show_border=True,
        show_labels=False,
    )

    _connectivity_group = Group(
        HGroup(
            Item('_connectivity_text', style='readonly',
                 resizable=False),
            spring,
            Item('connectivity_triangles',
                 editor=EnumEditor(name='_data_sources_names'),
                 show_label=False,
                 ),
            show_labels=False,
        ),
        label='Connectivity information',
        show_border=True,
        show_labels=False,
        enabled_when='position_type_=="explicit"',
    )

    _scalar_data_group = Group(
        Item('_scalar_data_text', style='readonly',
             resizable=False,
             show_label=False),
        HGroup(
            spring,
            Item('scalar_data',
                 editor=EnumEditor(name='_data_sources_names',
                                   invalid='_is_not_ok')),
            show_labels=False,
        ),
        label='Scalar value',
        show_border=True,
        show_labels=False,
    )

    _optional_scalar_data_group = Group(
        HGroup(
            'has_scalar_data',
            Item('_optional_scalar_data_text',
                 resizable=False,
                 style='readonly'),
            show_labels=False,
        ),
        Item('_scalar_data_text', style='readonly',
             resizable=False,
             enabled_when='has_scalar_data',
             show_label=False),
        HGroup(
            spring,
            Item('scalar_data',
                 editor=EnumEditor(name='_data_sources_names',
                                   invalid='_is_not_ok'),
                 enabled_when='has_scalar_data'),
            show_labels=False,
        ),
        label='Scalar data',
        show_border=True,
        show_labels=False,
    )

    _vector_data_group = VGroup(
        HGroup(
            Item('vector_u', label='u',
                 editor=EnumEditor(name='_data_sources_names',
                                   invalid='_is_not_ok')),
            Item('vector_v', label='v',
                 editor=EnumEditor(name='_data_sources_names',
                                   invalid='_is_not_ok')),
            Item('vector_w', label='w',
                 editor=EnumEditor(name='_data_sources_names',
                                   invalid='_is_not_ok')),
        ),
        label='Vector data',
        show_border=True,
    )

    _optional_vector_data_group = VGroup(
        HGroup(
            Item('has_vector_data', show_label=False),
            Item('_vector_data_text', style='readonly',
                 resizable=False,
                 show_label=False),
        ),
        HGroup(
            Item('vector_u', label='u',
                 editor=EnumEditor(name='_data_sources_names',
                                   invalid='_is_not_ok')),
            Item('vector_v', label='v',
                 editor=EnumEditor(name='_data_sources_names',
                                   invalid='_is_not_ok')),
            Item('vector_w', label='w',
                 editor=EnumEditor(name='_data_sources_names',
                                   invalid='_is_not_ok')),
            enabled_when='has_vector_data',
        ),
        label='Vector data',
        show_border=True,
    )

    _array_view = View(
        Item('_array_label', editor=TitleEditor(),
             show_label=False),
        Group(
            Item('_data_sources_wrappers',
                 editor=TabularEditor(
                     adapter=ArrayColumnAdapter(),
                 ),
                 ),
            show_border=True,
            show_labels=False
        ))

    _questions_view = View(
        Item('_top_label', editor=TitleEditor(),
             show_label=False),
        HGroup(
            Item('_data_type_text', style='readonly',
                 resizable=False),
            spring,
            'data_type',
            spring,
            show_border=True,
            show_labels=False,
        ),
        HGroup(
            Item('_self', style='custom',
                 editor=InstanceEditor(
                     view_name='_suitable_traits_view'),
                 ),
            Group(
                # FIXME: Giving up on context sensitive help
                # because of lack of time.
                #Group(
                #    Item('_shown_help_text', editor=HTMLEditor(),
                #        width=300,
                #        label='Help',
                #        ),
                #    show_labels=False,
                #    label='Help',
                #),
                #Group(
                Item('_preview_button',
                     enabled_when='_is_ok'),
                Item('_preview_window', style='custom',
                     label='Preview structure'),
                show_labels=False,
                #label='Preview structure',
                #),
                #layout='tabbed',
                #dock='tab',
            ),
            show_labels=False,
            show_border=True,
        ),
    )

    _point_data_view = View(Group(
        Group(_coordinates_group,
              label='Position of the data points',
              show_border=True,
              ),
        HGroup(
            'lines',
            Item('_lines_text', style='readonly',
                 resizable=False),
            label='Lines',
            show_labels=False,
            show_border=True,
        ),
        _optional_scalar_data_group,
        _optional_vector_data_group,
        # XXX: hack to have more vertical space
        Label('\n'),
        Label('\n'),
        Label('\n'),
    ))

    _surface_data_view = View(Group(
        _position_group,
        _connectivity_group,
        _optional_scalar_data_group,
        _optional_vector_data_group,
    ))

    _vector_data_view = View(Group(
        _vector_data_group,
        _position_group,
        _optional_scalar_data_group,
    ))

    _volumetric_data_view = View(Group(
        _scalar_data_group,
        _position_group,
        _optional_vector_data_group,
    ))

    _wizard_view = View(
        Group(
            HGroup(
                Item('_self', style='custom', show_label=False,
                     editor=InstanceEditor(view='_array_view'),
                     width=0.17,
                     ),
                '_',
                Item('_self', style='custom', show_label=False,
                     editor=InstanceEditor(view='_questions_view'),
                     ),
            ),
            HGroup(
                Item('_info_image', editor=ImageEditor(),
                     visible_when="_is_not_ok"),
                Item('_info_text', style='readonly', resizable=False,
                     visible_when="_is_not_ok"),
                spring,
                '_cancel_button',
                Item('_ok_button', enabled_when='_is_ok'),
                show_labels=False,
            ),
        ),
        title='Import arrays',
        resizable=True,
    )

    #----------------------------------------------------------------------
    # Public interface
    #----------------------------------------------------------------------

    def __init__(self, **traits):
        # NOTE(review): calls DataSourceFactory.__init__ directly rather than
        # the DataSourceWizard parent via super() — confirm this is intended.
        DataSourceFactory.__init__(self, **traits)
        # Self-reference so views can embed this object through an Item.
        self._self = self

    def view_wizard(self):
        """ Pops up the view of the wizard, and keeps a reference to it
            to be able to close it.
        """
        # FIXME: Workaround for traits bug in enabled_when
        self.position_type_
        self.data_type_
        self._suitable_traits_view
        self.grid_shape_source
        self._is_ok
        self.ui = self.edit_traits(view='_wizard_view')

    def preview(self):
        """ Display a preview of the data structure in the preview
            window.
        """
        self._preview_window.clear()
        self._preview_window.add_source(self.data_source)
        g = Glyph()
        g.glyph.glyph_source.glyph_source = \
                    g.glyph.glyph_source.glyph_list[0]
        g.glyph.scale_mode = 'data_scaling_off'
        # Without vector data, draw cross glyphs rendered as 3-pixel points.
        if not (self.has_vector_data or self.data_type_ == 'vector'):
            g.glyph.glyph_source.glyph_source.glyph_type = 'cross'
            g.actor.property.representation = 'points'
            g.actor.property.point_size = 3.
        self._preview_window.add_module(g)
        # Semi-transparent surface, unless the data is points/vectors
        # without connecting lines.
        if self.data_type_ not in ('point', 'vector') or self.lines:
            s = Surface()
            s.actor.property.opacity = 0.3
            self._preview_window.add_module(s)
        # Show the extracted edges for everything except point data.
        if self.data_type_ != 'point':
            self._preview_window.add_filter(ExtractEdges())
            s = Surface()
            s.actor.property.opacity = 0.2
            self._preview_window.add_module(s)
Exemple #27
0
class FitGui(HasTraits):
    """
    This class represents the fitgui application state.

    It holds the data being fit (`data`), the fit weights (`weights`), the
    current model (wrapped in a :class:`TraitedModel` as `tmodel`) and the
    Chaco plot, colorbar and interaction tools used to display and edit them.
    """

    plot = Instance(Plot)
    colorbar = Instance(ColorBar)
    plotcontainer = Instance(HPlotContainer)
    tmodel = Instance(TraitedModel,allow_none=False)
    nomodel = Property
    newmodel = Button('New Model...')
    fitmodel = Button('Fit Model')
    showerror = Button('Fit Error')
    updatemodelplot = Button('Update Model Plot')
    autoupdate = Bool(True)
    data = Array(dtype=float,shape=(2,None))
    weights = Array
    weighttype = Enum(('custom','equal','lin bins','log bins'))
    weightsvary = Property(Bool)
    weights0rem = Bool(True)
    modelselector = NewModelSelector
    ytype = Enum(('data and model','residuals'))

    zoomtool = Instance(ZoomTool)
    pantool = Instance(PanTool)

    scattertool = Enum(None,'clicktoggle','clicksingle','clickimmediate','lassoadd','lassoremove','lassoinvert')
    selectedi = Property #indices of the selected objects
    weightchangesel = Button('Set Selection To')
    weightchangeto = Float(1.0)
    delsel = Button('Delete Selected')
    unselectonaction = Bool(True)
    clearsel = Button('Clear Selections')
    lastselaction = Str('None')

    datasymb = Button('Data Symbol...')
    modline = Button('Model Line...')

    savews = Button('Save Weights')
    loadws = Button('Load Weights')
    _savedws = Array

    plotname = Property
    updatestats = Event
    chi2 = Property(Float,depends_on='updatestats')
    chi2r = Property(Float,depends_on='updatestats')


    nmod = Int(1024)
    #modelpanel = View(Label('empty'),kind='subpanel',title='model editor')
    modelpanel = View

    panel_view = View(VGroup(
                       Item('plot', editor=ComponentEditor(),show_label=False),
                       HGroup(Item('tmodel.modelname',show_label=False,style='readonly'),
                              Item('nmod',label='Number of model points'),
                              Item('updatemodelplot',show_label=False,enabled_when='not autoupdate'),
                              Item('autoupdate',label='Auto?'))
                      ),
                    title='Model Data Fitter'
                    )


    selection_view = View(Group(
                           Item('scattertool',label='Selection Mode',
                                 editor=EnumEditor(values={None:'1:No Selection',
                                                           'clicktoggle':'3:Toggle Select',
                                                           'clicksingle':'2:Single Select',
                                                           'clickimmediate':'7:Immediate',
                                                           'lassoadd':'4:Add with Lasso',
                                                           'lassoremove':'5:Remove with Lasso',
                                                           'lassoinvert':'6:Invert with Lasso'})),
                           Item('unselectonaction',label='Clear Selection on Action?'),
                           Item('clearsel',show_label=False),
                           Item('weightchangesel',show_label=False),
                           Item('weightchangeto',label='Weight'),
                           Item('delsel',show_label=False)
                         ),title='Selection Options')

    traits_view = View(VGroup(
                        HGroup(Item('object.plot.index_scale',label='x-scaling',
                                    enabled_when='object.plot.index_mapper.range.low>0 or object.plot.index_scale=="log"'),
                              spring,
                              Item('ytype',label='y-data'),
                              Item('object.plot.value_scale',label='y-scaling',
                                   enabled_when='object.plot.value_mapper.range.low>0 or object.plot.value_scale=="log"')
                              ),
                       Item('plotcontainer', editor=ComponentEditor(),show_label=False),
                       HGroup(VGroup(HGroup(Item('weighttype',label='Weights:'),
                                            Item('savews',show_label=False),
                                            Item('loadws',enabled_when='_savedws',show_label=False)),
                                Item('weights0rem',label='Remove 0-weight points for fit?'),
                                HGroup(Item('newmodel',show_label=False),
                                       Item('fitmodel',show_label=False),
                                       Item('showerror',show_label=False,enabled_when='tmodel.lastfitfailure'),
                                       VGroup(Item('chi2',label='Chi2:',style='readonly',format_str='%6.6g',visible_when='tmodel.model is not None'),
                                             Item('chi2r',label='reduced:',style='readonly',format_str='%6.6g',visible_when='tmodel.model is not None'))
                                       )#Item('selbutton',show_label=False))
                              ,springy=False),spring,
                              VGroup(HGroup(Item('autoupdate',label='Auto?'),
                              Item('updatemodelplot',show_label=False,enabled_when='not autoupdate')),
                              Item('nmod',label='Nmodel'),
                              HGroup(Item('datasymb',show_label=False),Item('modline',show_label=False)),springy=False),springy=True),
                       '_',
                       HGroup(Item('scattertool',label='Selection Mode',
                                 editor=EnumEditor(values={None:'1:No Selection',
                                                           'clicktoggle':'3:Toggle Select',
                                                           'clicksingle':'2:Single Select',
                                                           'clickimmediate':'7:Immediate',
                                                           'lassoadd':'4:Add with Lasso',
                                                           'lassoremove':'5:Remove with Lasso',
                                                           'lassoinvert':'6:Invert with Lasso'})),
                           Item('unselectonaction',label='Clear Selection on Action?'),
                           Item('clearsel',show_label=False),
                           Item('weightchangesel',show_label=False),
                           Item('weightchangeto',label='Weight'),
                           Item('delsel',show_label=False),
                         ),#layout='flow'),
                       Item('tmodel',show_label=False,style='custom',editor=InstanceEditor(kind='subpanel'))
                      ),
                    handler=FGHandler(),
                    resizable=True,
                    title='Data Fitting',
                    buttons=['OK','Cancel'],
                    width=700,
                    height=900
                    )


    def __init__(self,xdata=None,ydata=None,weights=None,model=None,
                 include_models=None,exclude_models=None,fittype=None,**traits):
        """

        :param xdata: the first dimension of the data to be fit
        :type xdata: array-like
        :param ydata: the second dimension of the data to be fit
        :type ydata: array-like
        :param weights:
            The weights to apply to the data. Statistically interpreted as inverse
            errors (*not* inverse variance). May be any of the following forms:

            * None for equal weights
            * an array of points that must match `ydata`
            * a 2-sequence of arrays (xierr,yierr) such that xierr matches the
              `xdata` and yierr matches `ydata`
            * a function called as f(params) that returns an array of weights
              that match one of the above two conditions

        :param model: the initial model to use to fit this data
        :type model:
            None, string, or :class:`pymodelfit.core.FunctionModel1D`
            instance.
        :param include_models:
            With `exclude_models`, specifies which models should be available in
            the "new model" dialog (see `models.list_models` for syntax).
        :param exclude_models:
            With `include_models`, specifies which models should be available in
            the "new model" dialog (see `models.list_models` for syntax).
        :param fittype:
            The fitting technique for the initial fit (see
            :class:`pymodelfit.core.FunctionModel`).
        :type fittype: string

        :raises ValueError:
            If `xdata` or `ydata` is missing and the model has no stored data.

        kwargs are passed in as any additional traits to apply to the
        application.

        """

        self.modelpanel = View(Label('empty'),kind='subpanel',title='model editor')

        self.tmodel = TraitedModel(model)

        if model is not None and fittype is not None:
            self.tmodel.model.fittype = fittype

        if xdata is None or ydata is None:
            if not hasattr(self.tmodel.model,'data') or self.tmodel.model.data is None:
                raise ValueError('data not provided and no data in model')
            if xdata is None:
                xdata = self.tmodel.model.data[0]
            if ydata is None:
                ydata = self.tmodel.model.data[1]
            if weights is None:
                weights = self.tmodel.model.data[2]

        self.on_trait_change(self._paramsChanged,'tmodel.paramchange')

        self.modelselector = NewModelSelector(include_models,exclude_models)

        self.data = [xdata,ydata]


        if weights is None:
            self.weights = np.ones_like(xdata)
            self.weighttype = 'equal'
        else:
            self.weights = np.array(weights,copy=True)
            self.savews = True

        #collapse (xierr,yierr)-style weights to one value per point for the colormap
        weights1d = self.weights
        while len(weights1d.shape)>1:
            weights1d = np.sum(weights1d**2,axis=0)

        pd = ArrayPlotData(xdata=self.data[0],ydata=self.data[1],weights=weights1d)
        self.plot = plot = Plot(pd,resizable='hv')

        self.scatter = plot.plot(('xdata','ydata','weights'),name='data',
                         color_mapper=_cmapblack if self.weights0rem else _cmap,
                         type='cmap_scatter', marker='circle')[0]

        self.errorplots = None

        if not isinstance(model,FunctionModel1D):
            self.fitmodel = True

        self.updatemodelplot = False #force plot update - generates xmod and ymod
        plot.plot(('xmod','ymod'),name='model',type='line',line_style='dash',color='black',line_width=2)
        del plot.x_mapper.range.sources[-1]  #remove the line plot from the x_mapper source so only the data is tied to the scaling

        self.on_trait_change(self._rangeChanged,'plot.index_mapper.range.updated')

        self.pantool = PanTool(plot,drag_button='left')
        plot.tools.append(self.pantool)
        self.zoomtool = ZoomTool(plot)
        self.zoomtool.prev_state_key = KeySpec('a')
        self.zoomtool.next_state_key = KeySpec('s')
        plot.overlays.append(self.zoomtool)

        self.scattertool = None
        self.scatter.overlays.append(ScatterInspectorOverlay(self.scatter,
                        hover_color = "black",
                        selection_color="black",
                        selection_outline_color="red",
                        selection_line_width=2))


        self.colorbar = colorbar = ColorBar(index_mapper=LinearMapper(range=plot.color_mapper.range),
                                            color_mapper=plot.color_mapper.range,
                                            plot=plot,
                                            orientation='v',
                                            resizable='v',
                                            width = 30,
                                            padding = 5)
        colorbar.padding_top = plot.padding_top
        colorbar.padding_bottom = plot.padding_bottom
        colorbar._axis.title = 'Weights'

        self.plotcontainer = container = HPlotContainer(use_backbuffer=True)
        container.add(plot)
        container.add(colorbar)

        super(FitGui,self).__init__(**traits)

        self.on_trait_change(self._scale_change,'plot.value_scale,plot.index_scale')

        if weights is not None and len(weights)==2:
            self.weightsChanged() #update error bars

    def _weights0rem_changed(self,old,new):
        """Switch the scatter colormap when 0-weight removal is toggled."""
        if new:
            self.plot.color_mapper = _cmapblack(self.plot.color_mapper.range)
        else:
            self.plot.color_mapper = _cmap(self.plot.color_mapper.range)
        self.plot.request_redraw()
#        if old and self.filloverlay in self.plot.overlays:
#            self.plot.overlays.remove(self.filloverlay)
#        if new:
#            self.plot.overlays.append(self.filloverlay)
#        self.plot.request_redraw()

    def _paramsChanged(self):
        """Request a model-plot refresh when the model parameters change."""
        self.updatemodelplot = True

    def _nmod_changed(self):
        """Request a model-plot refresh when the number of model points changes."""
        self.updatemodelplot = True

    def _rangeChanged(self):
        """Request a model-plot refresh when the x-range of the plot changes."""
        self.updatemodelplot = True

    #@on_trait_change('object.plot.value_scale,object.plot.index_scale',post_init=True)
    def _scale_change(self):
        """Redraw the plot after a change of axis scaling."""
        self.plot.request_redraw()

    def _updatemodelplot_fired(self,new):
        """
        Regenerate the model curve, or the residuals, shown in the plot.

        In 'data and model' mode the model is sampled at `nmod` points across
        the current x-range (log-spaced when the x-axis is logarithmic). In
        'residuals' mode the model curve is cleared and the y-data is replaced
        by the model residuals. A truthy `new` is ignored unless `autoupdate`
        is set; a falsy `new` always updates.
        """
        #If the plot has not been generated yet, just skip the update
        if self.plot is None:
            return

        #if False (e.g. button click), update regardless, otherwise check for autoupdate
        if new and not self.autoupdate:
            return

        mod = self.tmodel.model
        if self.ytype == 'data and model':
            if mod:
                #xd = self.data[0]
                #xmod = np.linspace(np.min(xd),np.max(xd),self.nmod)
                xl = self.plot.index_range.low
                xh = self.plot.index_range.high
                if self.plot.index_scale=="log":
                    xmod = np.logspace(np.log10(xl),np.log10(xh),self.nmod)
                else:
                    xmod = np.linspace(xl,xh,self.nmod)
                ymod = self.tmodel.model(xmod)

                self.plot.data.set_data('xmod',xmod)
                self.plot.data.set_data('ymod',ymod)

            else:
                self.plot.data.set_data('xmod',[])
                self.plot.data.set_data('ymod',[])
        elif self.ytype == 'residuals':
            if mod:
                self.plot.data.set_data('xmod',[])
                self.plot.data.set_data('ymod',[])
                #residuals set the ydata instead of setting the model
                res = mod.residuals(*self.data)
                self.plot.data.set_data('ydata',res)
            else:
                self.ytype = 'data and model'
        else:
            #unreachable while ytype is a 2-value Enum; the previous
            #``assert True, ...`` could never fire, so fail loudly instead
            raise ValueError('invalid Enum')


    def _fitmodel_fired(self):
        """
        Fit the current model to the data using the current weights.

        If `weights0rem` is set and the weights match the shape of the
        x-data, 0-weighted points are dropped before fitting; otherwise a
        warning is issued when 0 weights are present. Auto-updating of the
        model plot is suspended during the fit and restored afterwards.
        """
        from warnings import warn

        preaup = self.autoupdate
        try:
            self.autoupdate = False
            xd,yd = self.data
            kwd = {'x':xd,'y':yd}
            if self.weights is not None:
                w = self.weights
                if self.weights0rem:
                    if xd.shape == w.shape:
                        m = w!=0
                        w = w[m]
                        kwd['x'] = kwd['x'][m]
                        kwd['y'] = kwd['y'][m]
                    elif np.any(w==0):
                        warn("can't remove 0-weighted points if weights don't match data")
                kwd['weights'] = w
            self.tmodel.fitdata = kwd
        finally:
            self.autoupdate = preaup

        self.updatemodelplot = True
        self.updatestats = True


#    def _tmodel_changed(self,old,new):
#        #old is only None before it is initialized
#        if new is not None and old is not None and new.model is not None:
#            self.fitmodel = True

    def _newmodel_fired(self,newval):
        """
        Replace the current model.

        If `newval` is a model name, a FunctionModel1D instance or a
        FunctionModel1D subclass it is used directly. Otherwise the model
        selector dialog is shown modally; on acceptance the selected model is
        instantiated and fit, and on cancel nothing changes.
        """
        from inspect import isclass

        if isinstance(newval,basestring) or isinstance(newval,FunctionModel1D) \
           or (isclass(newval) and issubclass(newval,FunctionModel1D)):
            self.tmodel = TraitedModel(newval)
        else:
            if self.modelselector.edit_traits(kind='modal').result:
                cls = self.modelselector.selectedmodelclass
                if cls is None:
                    self.tmodel = TraitedModel(None)
                elif self.modelselector.isvarargmodel:
                    self.tmodel = TraitedModel(cls(self.modelselector.modelargnum))
                    self.fitmodel = True
                else:
                    self.tmodel = TraitedModel(cls())
                    self.fitmodel = True
            else: #cancelled
                return

    def _showerror_fired(self,evt):
        """Show the last fitting exception (if any) in a small dialog."""
        if self.tmodel.lastfitfailure:
            ex = self.tmodel.lastfitfailure
            dialog = HasTraits(s=ex.__class__.__name__+': '+str(ex))
            view = View(Item('s',style='custom',show_label=False),
                        resizable=True,buttons=['OK'],title='Fitting error message')
            dialog.edit_traits(view=view)

    @cached_property
    def _get_chi2(self):
        """Chi-squared of the current fit, or 0 if it cannot be computed."""
        try:
            return self.tmodel.model.chi2Data()[0]
        except Exception:
            #best-effort display value (e.g. no model or no fit yet)
            return 0

    @cached_property
    def _get_chi2r(self):
        """Reduced chi-squared of the current fit, or 0 if it cannot be computed."""
        try:
            return self.tmodel.model.chi2Data()[1]
        except Exception:
            #best-effort display value (e.g. no model or no fit yet)
            return 0

    def _get_nomodel(self):
        return self.tmodel.model is None

    def _get_weightsvary(self):
        """True if any weight differs from the first one (False when empty)."""
        w = self.weights
        return np.any(w!=w[0]) if len(w)>0 else False

    def _get_plotname(self):
        """'<x title> vs <y title>', or '' when both axis titles are empty."""
        xlabel = self.plot.x_axis.title
        ylabel = self.plot.y_axis.title
        if xlabel == '' and ylabel == '':
            return ''
        else:
            return xlabel+' vs '+ylabel
    def _set_plotname(self,val):
        """
        Set the x and y axis titles.

        `val` is either a 2-sequence (xtitle, ytitle) or a string of the form
        'xtitle vs ytitle' (falling back to 'xtitle - ytitle').
        """
        if isinstance(val,basestring):
            #split the original string; the previous code called .split('-')
            #on the already-split list, which raised AttributeError
            parts = val.split('vs')
            if len(parts) == 1:
                parts = val.split('-')
            val = [v.strip() for v in parts]
        #the axes belong to the Plot, not to this object
        self.plot.x_axis.title = val[0]
        self.plot.y_axis.title = val[1]


    #selection-related
    def _scattertool_changed(self,old,new):
        """
        Install the scatter selection tool matching the new selection mode.

        Tears down the previous mode's listeners/overlay, then attaches a
        ScatterInspector (click modes) or a LassoSelection plus overlay
        (lasso modes). The pan tool uses the left button only when no
        selection mode is active.
        """
        #the "No Selection" mode is the value None (see the EnumEditor)
        if new is None:
            self.plot.tools[0].drag_button='left'
        else:
            self.plot.tools[0].drag_button='right'
        if old is not None and 'lasso' in old:
            if new is not None and 'lasso' in new:
                #connect correct callbacks
                self.lassomode = new.replace('lasso','')
                return
            else:
                #TODO:test
                lasso_selection = self.scatter.tools[-1]
                for evtname in ('selection_changed','selection_completed','updated'):
                    lasso_selection.on_trait_change(self._lasso_handler,
                                                    evtname,remove=True)
                del self.scatter.overlays[-1]
                del self.lassomode
        elif old == 'clickimmediate':
            self.scatter.index.on_trait_change(self._immediate_handler,
                                            'metadata_changed',remove=True)

        self.scatter.tools = []
        if new is None:
            pass
        elif 'click' in new:
            smodemap = {'clickimmediate':'single','clicksingle':'single',
                        'clicktoggle':'toggle'}
            self.scatter.tools.append(ScatterInspector(self.scatter,
                                      selection_mode=smodemap[new]))
            if new == 'clickimmediate':
                self.clearsel = True
                self.scatter.index.on_trait_change(self._immediate_handler,
                                                    'metadata_changed')
        elif 'lasso' in new:
            lasso_selection = LassoSelection(component=self.scatter,
                                    selection_datasource=self.scatter.index)
            self.scatter.tools.append(lasso_selection)
            lasso_overlay = LassoOverlay(lasso_selection=lasso_selection,
                                         component=self.scatter)
            self.scatter.overlays.append(lasso_overlay)
            self.lassomode = new.replace('lasso','')
            lasso_selection.on_trait_change(self._lasso_handler,
                                            'selection_changed')
            lasso_selection.on_trait_change(self._lasso_handler,
                                            'selection_completed')
            lasso_selection.on_trait_change(self._lasso_handler,
                                            'updated')
        else:
            raise TraitsError('invalid scattertool value')

    def _weightchangesel_fired(self):
        """Set the weight of every selected point to `weightchangeto`."""
        self.weights[self.selectedi] = self.weightchangeto
        if self.unselectonaction:
            self.clearsel = True

        self._sel_alter_weights()
        self.lastselaction = 'weightchangesel'

    def _delsel_fired(self):
        """'Delete' the selected points by setting their weight to 0."""
        self.weights[self.selectedi] = 0
        if self.unselectonaction:
            self.clearsel = True

        self._sel_alter_weights()
        self.lastselaction = 'delsel'

    def _sel_alter_weights(self):
        """Switch to custom weights after an in-place edit and refresh."""
        if self.weighttype != 'custom':
            self._customweights = self.weights
            self.weighttype = 'custom'
        self.weightsChanged()

    def _clearsel_fired(self,event):
        """Replace the selection with `event` if it is a list, else clear it."""
        if isinstance(event,list):
            self.scatter.index.metadata['selections'] = event
        else:
            self.scatter.index.metadata['selections'] = list()

    def _lasso_handler(self,name,new):
        """
        React to LassoSelection events.

        'selection_changed' combines the lasso mask with the click selection
        according to `lassomode` ('add', 'remove' or 'invert'); the other two
        events hide/show the lasso overlay.
        """
        if name == 'selection_changed':
            lassomask = self.scatter.index.metadata['selection'].astype(int)
            clickmask = np.zeros_like(lassomask)
            clickmask[self.scatter.index.metadata['selections']] = 1

            if self.lassomode == 'add':
                mask = clickmask | lassomask
            elif self.lassomode == 'remove':
                mask = clickmask & ~lassomask
            elif self.lassomode == 'invert':
                mask = np.logical_xor(clickmask,lassomask)
            else:
                raise TraitsError('lassomode is in invalid state')

            self.scatter.index.metadata['selections'] = list(np.where(mask)[0])
        elif name == 'selection_completed':
            self.scatter.overlays[-1].visible = False
        elif name == 'updated':
            self.scatter.overlays[-1].visible = True
        else:
            raise ValueError('traits event name %s invalid'%name)

    def _immediate_handler(self):
        """
        In 'clickimmediate' mode, re-fire the last selection action on the
        single clicked point and then drop it from the selection.
        """
        sel = self.selectedi
        if len(sel) > 1:
            self.clearsel = True
            raise TraitsError('selection error in immediate mode - more than 1 selection')
        elif len(sel)==1:
            if self.lastselaction != 'None':
                setattr(self,self.lastselaction,True)
            del sel[0]

    def _savews_fired(self):
        """Store a copy of the current weights."""
        self._savedws = self.weights.copy()

    def _loadws_fired(self):
        """Restore the saved weights, keeping a fresh copy saved."""
        self.weights = self._savedws
        self._savews_fired()

    def _get_selectedi(self):
        return self.scatter.index.metadata['selections']


    @on_trait_change('data,ytype',post_init=True)
    def dataChanged(self):
        """
        Updates the application state if the fit data are altered - the GUI will
        know if you give it a new data array, but not if the data is changed
        in-place.
        """
        pd = self.plot.data
        #TODO:make set_data apply to both simultaneously?
        pd.set_data('xdata',self.data[0])
        pd.set_data('ydata',self.data[1])

        self.updatemodelplot = False

    @on_trait_change('weights',post_init=True)
    def weightsChanged(self):
        """
        Updates the application state if the weights/error bars for this model
        are changed - the GUI will automatically do this if you give it a new
        set of weights array, but not if they are changed in-place.
        """
        weights = self.weights
        if 'errorplots' in self.trait_names():
            #TODO:switch this to updating error bar data/visibility changing
            if self.errorplots is not None:
                self.plot.remove(self.errorplots[0])
                self.plot.remove(self.errorplots[1])
                #reset the attribute that was actually checked above (it was
                #previously misspelled 'errorbarplots', so stale plots were
                #"removed" again on the next call)
                self.errorplots = None

            if len(weights.shape)==2 and weights.shape[0]==2:
                #weights are inverse errors, so the error bar half-widths are 1/w
                xerr,yerr = 1/weights

                high = ArrayDataSource(self.scatter.index.get_data()+xerr)
                low = ArrayDataSource(self.scatter.index.get_data()-xerr)
                ebpx = ErrorBarPlot(orientation='v',
                                   value_high = high,
                                   value_low = low,
                                   index = self.scatter.value,
                                   value = self.scatter.index,
                                   index_mapper = self.scatter.value_mapper,
                                   value_mapper = self.scatter.index_mapper
                                )
                self.plot.add(ebpx)

                high = ArrayDataSource(self.scatter.value.get_data()+yerr)
                low = ArrayDataSource(self.scatter.value.get_data()-yerr)
                ebpy = ErrorBarPlot(value_high = high,
                                   value_low = low,
                                   index = self.scatter.index,
                                   value = self.scatter.value,
                                   index_mapper = self.scatter.index_mapper,
                                   value_mapper = self.scatter.value_mapper
                                )
                self.plot.add(ebpy)

                self.errorplots = (ebpx,ebpy)

        #collapse multi-dimensional weights to one value per point for the colormap
        while len(weights.shape)>1:
            weights = np.sum(weights**2,axis=0)
        self.plot.data.set_data('weights',weights)
        self.plot.plots['data'][0].color_mapper.range.refresh()

        #only show the colorbar when the weights actually vary
        if self.weightsvary:
            if self.colorbar not in self.plotcontainer.components:
                self.plotcontainer.add(self.colorbar)
                self.plotcontainer.request_redraw()
        elif self.colorbar in self.plotcontainer.components:
            self.plotcontainer.remove(self.colorbar)
            self.plotcontainer.request_redraw()


    def _weighttype_changed(self, name, old, new):
        """
        Regenerate the weights for the newly selected weighting scheme,
        stashing the current weights when leaving 'custom' so they can be
        restored when 'custom' is selected again.
        """
        if old == 'custom':
            self._customweights = self.weights

        if new == 'custom':
            self.weights = self._customweights #if hasattr(self,'_customweights') else np.ones_like(self.data[0])
        elif new == 'equal':
            self.weights = np.ones_like(self.data[0])
        elif new == 'lin bins':
            self.weights = binned_weights(self.data[0],10,False)
        elif new == 'log bins':
            self.weights = binned_weights(self.data[0],10,True)
        else:
            raise TraitError('Invalid Enum value on weighttype')

    def getModelInitStr(self):
        """
        Generates a python code string that can be used to generate a model with
        parameters matching the model in this :class:`FitGui`.

        :returns: initializer string ('None' if there is no model)

        """
        mod = self.tmodel.model
        if mod is None:
            return 'None'
        else:
            parstrs = []
            for p,v in mod.pardict.iteritems():
                parstrs.append(p+'='+str(v))
            if mod.__class__._pars is None: #varargs need to have the first argument give the right number
                varcount = len(mod.params)-len(mod.__class__._statargs)
                parstrs.insert(0,str(varcount))
            return '%s(%s)'%(mod.__class__.__name__,','.join(parstrs))

    def getModelObject(self):
        """
        Gets the underlying object representing the model for this fit.

        :returns: The :class:`pymodelfit.core.FunctionModel1D` object.
        """
        return self.tmodel.model
Exemple #28
0
class SceneModel(TVTKScene):
    """A TVTKScene look-alike that delegates its real work to a scene editor.

    Actors and widgets are tracked in ``actor_list`` (and ``enabled_info``
    for widgets); ``render`` only fires the ``do_render`` event, and the
    saving/sizing methods forward to ``scene_editor``, raising
    ``SceneModelError`` when no editor is attached.
    """

    ########################################
    # TVTKScene traits.

    light_manager = Property

    picker = Property

    ########################################
    # SceneModel traits.

    # A convenient dictionary based interface to add/remove actors and widgets.
    # This is similar to the interface provided for the ActorEditor.
    actor_map = Dict()

    # This is used primarily to implement the add_actor/remove_actor methods.
    actor_list = List()

    # The actual scene being edited.
    scene_editor = Instance(TVTKScene)

    # Fired by `render`; listeners perform the actual rendering.
    do_render = Event()

    # Fired when this is activated.
    activated = Event()

    # Fired when this widget is closed.
    closing = Event()

    # This exists just to mirror the TVTKWindow api.
    scene = Property

    ###################################
    # View related traits.

    # Render_window's view.
    _stereo_view = Group(
        Item(name='stereo_render'),
        Item(name='stereo_type'),
        show_border=True,
        label='Stereo rendering',
    )

    # The default view of this object.
    default_view = View(
        Group(Group(
            Item(name='background'),
            Item(name='foreground'),
            Item(name='parallel_projection'),
            Item(name='disable_render'),
            Item(name='off_screen_rendering'),
            Item(name='jpeg_quality'),
            Item(name='jpeg_progressive'),
            Item(name='magnification'),
            Item(name='anti_aliasing_frames'),
        ),
              Group(
                  Item(name='render_window',
                       style='custom',
                       visible_when='object.stereo',
                       editor=InstanceEditor(view=View(_stereo_view)),
                       show_label=False), ),
              label='Scene'),
        Group(Item(name='light_manager',
                   style='custom',
                   editor=InstanceEditor(),
                   show_label=False),
              label='Lights'))

    ###################################
    # Private traits.

    # Used by the editor to determine if the widget was enabled or not.
    enabled_info = Dict()

    def __init__(self, parent=None, **traits):
        """ Initializes the object. """
        # Base class constructor.  We call TVTKScene's super here on purpose.
        # Calling TVTKScene's init will create a new window which we do not
        # want.
        # NOTE: `parent` is accepted for API compatibility but unused here.
        super(TVTKScene, self).__init__(**traits)
        self.control = None

    ######################################################################
    # TVTKScene API.
    ######################################################################
    def render(self):
        """ Request that the scene be rendered.

        This only fires the `do_render` event; the actual rendering (and,
        presumably, any `disable_render` check) is done by whoever listens
        to that event -- confirm in the scene editor."""

        self.do_render = True

    def add_actors(self, actors):
        """ Adds a single actor or a tuple or list of actors to the
        renderer."""
        if hasattr(actors, '__iter__'):
            self.actor_list.extend(actors)
        else:
            self.actor_list.append(actors)

    def remove_actors(self, actors):
        """ Removes a single actor or a tuple or list of actors from
        the renderer."""
        my_actors = self.actor_list
        if hasattr(actors, '__iter__'):
            for actor in actors:
                my_actors.remove(actor)
        else:
            my_actors.remove(actors)

    # Convenience methods.
    add_actor = add_actors
    remove_actor = remove_actors

    def add_widgets(self, widgets, enabled=True):
        """Adds a single widget or an iterable of widgets to the renderer.

        `enabled` is recorded per widget in `enabled_info`.
        """
        # Materialize the input once: it is iterated twice below, and a
        # one-shot iterator (e.g. a generator) would otherwise be exhausted
        # by the first loop so nothing would reach `add_actors`.
        if hasattr(widgets, '__iter__'):
            widgets = list(widgets)
        else:
            widgets = [widgets]
        for widget in widgets:
            self.enabled_info[widget] = enabled
        self.add_actors(widgets)

    def remove_widgets(self, widgets):
        """Removes a single widget or an iterable of widgets from the
        renderer, forgetting their `enabled_info` entries."""
        # Materialize once so that both passes below see every widget,
        # even when a one-shot iterator is passed in.
        if hasattr(widgets, '__iter__'):
            widgets = list(widgets)
        else:
            widgets = [widgets]
        self.remove_actors(widgets)
        for widget in widgets:
            del self.enabled_info[widget]

    def reset_zoom(self):
        """Reset the camera so everything in the scene fits."""
        if self.scene_editor is not None:
            self.scene_editor.reset_zoom()

    def save(self, file_name, size=None, **kw_args):
        """Saves rendered scene to one of several image formats
        depending on the specified extension of the filename.

        If an additional size (2-tuple) argument is passed the window
        is resized to the specified size in order to produce a
        suitably sized output image.  Please note that when the window
        is resized, the window may be obscured by other widgets and
        the camera zoom is not reset which is likely to produce an
        image that does not reflect what is seen on screen.

        Any extra keyword arguments are passed along to the respective
        image format's save method.
        """
        self._check_scene_editor()
        self.scene_editor.save(file_name, size, **kw_args)

    def save_ps(self, file_name):
        """Saves the rendered scene to a rasterized PostScript image.
        For vector graphics use the save_gl2ps method."""
        self._check_scene_editor()
        self.scene_editor.save_ps(file_name)

    def save_bmp(self, file_name):
        """Save to a BMP image file."""
        self._check_scene_editor()
        self.scene_editor.save_bmp(file_name)

    def save_tiff(self, file_name):
        """Save to a TIFF image file."""
        self._check_scene_editor()
        self.scene_editor.save_tiff(file_name)

    def save_png(self, file_name):
        """Save to a PNG image file."""
        self._check_scene_editor()
        self.scene_editor.save_png(file_name)

    def save_jpg(self, file_name, quality=None, progressive=None):
        """Arguments: file_name if passed will be used, quality is the
        quality of the JPEG (10-100 are valid), the progressive
        argument toggles progressive jpegs."""
        self._check_scene_editor()
        self.scene_editor.save_jpg(file_name, quality, progressive)

    def save_iv(self, file_name):
        """Save to an OpenInventor file."""
        self._check_scene_editor()
        self.scene_editor.save_iv(file_name)

    def save_vrml(self, file_name):
        """Save to a VRML file."""
        self._check_scene_editor()
        self.scene_editor.save_vrml(file_name)

    def save_oogl(self, file_name):
        """Saves the scene to a Geomview OOGL file. Requires VTK 4 to
        work."""
        self._check_scene_editor()
        self.scene_editor.save_oogl(file_name)

    def save_rib(self, file_name, bg=0, resolution=None, resfactor=1.0):
        """Save scene to a RenderMan RIB file.

        Keyword Arguments:

        file_name -- File name to save to.

        bg -- Background option forwarded to the scene editor.  If 0
        (the default) then no background is saved; any other non-None
        value saves a background.  Passing None may make the editor ask
        yes/no in a pop-up window -- confirm in the scene editor.

        resolution -- Specify the resolution of the generated image in
        the form of a tuple (nx, ny).

        resfactor -- The resolution factor which scales the resolution.
        """
        self._check_scene_editor()
        self.scene_editor.save_rib(file_name, bg, resolution, resfactor)

    def save_wavefront(self, file_name):
        """Save scene to a Wavefront OBJ file.  Two files are
        generated.  One with a .obj extension and another with a .mtl
        extension which contains the material properties.

        Keyword Arguments:

        file_name -- File name to save to
        """
        self._check_scene_editor()
        self.scene_editor.save_wavefront(file_name)

    def save_gl2ps(self, file_name, exp=None):
        """Save scene to a vector PostScript/EPS/PDF/TeX file using
        GL2PS.  If you choose to use a TeX file then note that only
        the text output is saved to the file.  You will need to save
        the graphics separately.

        Keyword Arguments:

        file_name -- File name to save to.

        exp -- Optionally configured vtkGL2PSExporter object.
        Defaults to None and this will use the default settings with
        the output file type chosen based on the extension of the file
        name.
        """
        self._check_scene_editor()
        self.scene_editor.save_gl2ps(file_name, exp)

    def get_size(self):
        """Return size of the render window."""
        self._check_scene_editor()
        return self.scene_editor.get_size()

    def set_size(self, size):
        """Set the size of the window."""
        self._check_scene_editor()
        self.scene_editor.set_size(size)

    def _update_view(self, x, y, z, vx, vy, vz):
        """Used internally to set the view."""
        if self.scene_editor is not None:
            self.scene_editor._update_view(x, y, z, vx, vy, vz)

    def _check_scene_editor(self):
        """Raise SceneModelError unless a scene editor is attached."""
        if self.scene_editor is None:
            msg = """
            This method requires that there be an active scene editor.
            To do this, you will typically need to invoke::
              object.edit_traits()
            where object is the object that contains the SceneModel.
            """
            raise SceneModelError(msg)

    def _scene_editor_changed(self, old, new):
        # Mirror the editor's low-level VTK objects (or clear them when the
        # editor goes away).
        if new is None:
            self._renderer = None
            self._renwin = None
            self._interactor = None
        else:
            self._renderer = new._renderer
            self._renwin = new._renwin
            self._interactor = new._interactor

    def _get_picker(self):
        """Getter for the picker."""
        se = self.scene_editor
        if se is not None and hasattr(se, 'picker'):
            return se.picker
        return None

    def _get_light_manager(self):
        """Getter for the light manager."""
        se = self.scene_editor
        if se is not None:
            return se.light_manager
        return None

    ######################################################################
    # SceneModel API.
    ######################################################################
    def _get_scene(self):
        """Getter for the scene property."""
        return self
Exemple #29
0
class GlyphSource(Component):
    """Pipeline component that provides a selectable glyph source.

    The chosen glyph (``glyph_source``) is optionally passed through a
    transform filter so that it can be positioned at the head, tail or
    centre of each data point.
    """

    # The version of this class.  Used for persistence.
    __version__ = 1

    # Glyph position.  This can be one of ['head', 'tail', 'center'],
    # and indicates the position of the glyph with respect to the
    # input point data.  Please note that this will work correctly
    # only if you do not mess with the source glyph's basic size.  For
    # example if you use a ConeSource and set its height != 1, then the
    # 'head' and 'tail' options will not work correctly.
    glyph_position = Trait('center',
                           TraitPrefixList(['head', 'tail', 'center']),
                           desc='position of glyph w.r.t. data point')

    # The Source to use for the glyph.  This is chosen from
    # `self._glyph_list` or `self.glyph_dict`.
    glyph_source = Instance(tvtk.Object, allow_none=False, record=True)

    # A dict of glyphs to use.
    glyph_dict = Dict(desc='the glyph sources to select from', record=False)

    # A list of predefined glyph sources that can be used.
    glyph_list = Property(List(tvtk.Object), record=False)

    ########################################
    # Private traits.

    # The transformation to use to place glyph appropriately.
    _trfm = Instance(tvtk.TransformFilter, args=())

    # Re-entrancy guard: True while the handlers below are updating, so
    # nested trait notifications and renders are suppressed.
    _updating = Bool(False)

    ########################################
    # View related traits.

    view = View(Group(
        Group(Item(name='glyph_position')),
        Group(Item(
            name='glyph_source',
            style='custom',
            resizable=True,
            editor=InstanceEditor(name='glyph_list'),
        ),
              label='Glyph Source',
              show_labels=False)),
                resizable=True)

    ######################################################################
    # `Base` interface
    ######################################################################
    def __get_pure_state__(self):
        d = super(GlyphSource, self).__get_pure_state__()
        # The guard flag and the computed glyph list are not persisted.
        for attr in ('_updating', 'glyph_list'):
            d.pop(attr, None)
        return d

    def __set_pure_state__(self, state):
        if 'glyph_dict' in state:
            # Set their state.
            set_state(self, state, first=['glyph_dict'], ignore=['*'])
            ignore = ['glyph_dict']
        else:
            # Older pickles only stored the list; rebuild the dict from it.
            gd = self.glyph_dict
            gl = self.glyph_list
            handle_children_state(gl, state.glyph_list)
            for g, gs in zip(gl, state.glyph_list):
                name = camel2enthought(g.__class__.__name__)
                if name not in gd:
                    gd[name] = g
                # Set the glyph source's state.
                set_state(g, gs)
            ignore = ['glyph_list']
        g_name = state.glyph_source.__metadata__['class_name']
        name = camel2enthought(g_name)
        # Set the correct glyph_source.
        self.glyph_source = self.glyph_dict[name]
        set_state(self, state, ignore=ignore)

    ######################################################################
    # `Component` interface
    ######################################################################
    def setup_pipeline(self):
        """Override this method so that it *creates* the tvtk
        pipeline.

        This method is invoked when the object is initialized via
        `__init__`.  Note that at the time this method is called, the
        tvtk data pipeline will *not* yet be setup.  So upstream data
        will not be available.  The idea is that you simply create the
        basic objects and setup those parts of the pipeline not
        dependent on upstream sources and filters.  You should also
        set the `actors` attribute up at this point.
        """

        self._trfm.transform = tvtk.Transform()
        # Setup the glyphs.
        self.glyph_source = self.glyph_dict['glyph_source2d']

    def update_pipeline(self):
        """Override this method so that it *updates* the tvtk pipeline
        when data upstream is known to have changed.

        This method is invoked (automatically) when any of the inputs
        sends a `pipeline_changed` event.
        """
        self._glyph_position_changed(self.glyph_position)
        self.pipeline_changed = True

    def update_data(self):
        """Override this method so that it flushes the vtk pipeline if
        that is necessary.

        This method is invoked (automatically) when any of the inputs
        sends a `data_changed` event.
        """
        self.data_changed = True

    def render(self):
        """Render, unless an internal update is in progress."""
        if not self._updating:
            super(GlyphSource, self).render()

    ######################################################################
    # Non-public methods.
    ######################################################################
    def _glyph_source_changed(self, value):
        """Register the new glyph source and route it to ``outputs``."""
        if self._updating:
            return

        gd = self.glyph_dict
        value_cls = camel2enthought(value.__class__.__name__)
        if value not in gd.values():
            gd[value_cls] = value

        # Now change the glyph's source trait.  The guard is reset in a
        # `finally` so an exception cannot leave rendering disabled.
        self._updating = True
        try:
            recorder = self.recorder
            if recorder is not None:
                name = recorder.get_script_id(self)
                lhs = '%s.glyph_source' % name
                rhs = '%s.glyph_dict[%r]' % (name, value_cls)
                recorder.record('%s = %s' % (lhs, rhs))

            name = value.__class__.__name__
            if name == 'GlyphSource2D':
                # The 2D glyph is positioned via its own `center`, so it
                # bypasses the transform filter.
                self.outputs = [value.output]
            else:
                self._trfm.input = value.output
                self.outputs = [self._trfm.output]
            value.on_trait_change(self.render)
        finally:
            self._updating = False

        # Now update the glyph position since the transformation might
        # be different.
        self._glyph_position_changed(self.glyph_position)

    def _glyph_position_changed(self, value):
        """Move the glyph so it sits at the head, tail or centre."""
        if self._updating:
            return

        self._updating = True
        try:
            tr = self._trfm.transform
            tr.identity()

            g = self.glyph_source
            name = g.__class__.__name__
            # Compute transformation factor
            if name == 'CubeSource':
                tr_factor = g.x_length / 2.0
            elif name == 'CylinderSource':
                tr_factor = -g.height / 2.0
            elif name == 'ConeSource':
                tr_factor = g.height / 2.0
            elif name == 'SphereSource':
                tr_factor = g.radius
            else:
                tr_factor = 1.
            # Translate the glyph.  Sources without a `center` (e.g.
            # tvtk.Axes) are left alone in every branch.
            if value == 'tail':
                if name == 'GlyphSource2D':
                    g.center = 0.5, 0.0, 0.0
                elif name == 'ArrowSource':
                    pass
                elif name == 'CylinderSource':
                    g.center = 0, tr_factor, 0.0
                elif hasattr(g, 'center'):
                    g.center = tr_factor, 0.0, 0.0
            elif value == 'head':
                if name == 'GlyphSource2D':
                    g.center = -0.5, 0.0, 0.0
                elif name == 'ArrowSource':
                    tr.translate(-1, 0, 0)
                elif name == 'CylinderSource':
                    g.center = 0, -tr_factor, 0.0
                elif hasattr(g, 'center'):
                    g.center = -tr_factor, 0.0, 0.0
            else:
                if name == 'ArrowSource':
                    tr.translate(-0.5, 0, 0)
                elif hasattr(g, 'center'):
                    g.center = 0.0, 0.0, 0.0

            if name == 'CylinderSource':
                # Rotate the cylinder about z so it lines up with the
                # x-axis like the other glyphs.
                tr.rotate_z(90)
        finally:
            self._updating = False
        self.render()

    def _get_glyph_list(self):
        # Return the glyph list as per the original order in earlier
        # implementation; any extra user-added glyphs follow.
        order = [
            'glyph_source2d', 'arrow_source', 'cone_source', 'cylinder_source',
            'sphere_source', 'cube_source', 'axes'
        ]
        gd = self.glyph_dict
        for key in gd:
            if key not in order:
                order.append(key)
        return [gd[key] for key in order]

    def _glyph_dict_default(self):
        g = {
            'glyph_source2d':
            tvtk.GlyphSource2D(glyph_type='arrow', filled=False),
            'arrow_source':
            tvtk.ArrowSource(),
            'cone_source':
            tvtk.ConeSource(height=1.0, radius=0.2, resolution=15),
            'cylinder_source':
            tvtk.CylinderSource(height=1.0, radius=0.15, resolution=10),
            'sphere_source':
            tvtk.SphereSource(),
            'cube_source':
            tvtk.CubeSource(),
            'axes':
            tvtk.Axes(symmetric=1)
        }
        return g
Exemple #30
0
class SourceWidget(Component):
    """Pipeline component whose output poly data comes from a 3D widget.

    The user picks one of ``widget_list`` (sphere, line, plane or point
    widget by default); its poly data is copied into ``poly_data`` and
    exposed as this component's output according to ``update_mode``.
    """

    # The version of this class.  Used for persistence.
    __version__ = 0

    # The actual poly data source widget.
    widget = Instance(tvtk.ThreeDWidget, record=True)

    # Specifies the update mode of the poly_data attribute.  There
    # are three modes: 1) 'interactive' -- the poly_data attribute is
    # updated as the widget is interacted with, 2) 'semi-interactive'
    # -- poly_data attribute is updated when the traits of the widget
    # change and when the widget interaction is complete, 3)
    # 'non-interactive' -- poly_data is updated only explicitly at
    # users request by calling `object.update_poly_data`.
    update_mode = Trait(
        'interactive',
        TraitPrefixList(['interactive', 'semi-interactive',
                         'non-interactive']),
        desc='the speed at which the poly data is updated')

    # A list of predefined widgets that can be used.
    widget_list = List(tvtk.Object, record=False)

    # The poly data that the widget manages.
    poly_data = Instance(tvtk.PolyData, args=())

    ########################################
    # Private traits.

    # True until the first widget placement in `update_pipeline`.
    _first = Bool(True)
    # Re-entrancy guard for the poly data update handlers.
    _busy = Bool(False)
    # True while `__set_pure_state__` runs; `_widget_changed` is skipped.
    _unpickling = Bool(False)

    ########################################
    # View related traits.

    view = View(
        Group(
            Item(name='widget',
                 style='custom',
                 resizable=True,
                 editor=InstanceEditor(name='widget_list')),
            label='Source Widget',
            show_labels=False,
        ),
        resizable=True,
    )

    ######################################################################
    # `Base` interface
    ######################################################################
    def __get_pure_state__(self):
        d = super(SourceWidget, self).__get_pure_state__()
        for attr in ('poly_data', '_unpickling', '_first', '_busy'):
            d.pop(attr, None)
        return d

    def __set_pure_state__(self, state):
        # The flag is reset in a `finally` so a failed restore does not
        # permanently disable `_widget_changed`.
        self._unpickling = True
        try:
            # First create all the allowed widgets in the widget_list attr.
            handle_children_state(self.widget_list, state.widget_list)
            # Now set their state.
            set_state(self, state, first=['widget_list'], ignore=['*'])
            # Set the widget attr depending on value saved.
            m = [x.__class__.__name__ for x in self.widget_list]
            w_c_name = state.widget.__metadata__['class_name']
            w = self.widget = self.widget_list[m.index(w_c_name)]
            # Set the input.
            if len(self.inputs) > 0:
                w.input = self.inputs[0].outputs[0]
            # Fix for the point widget.
            if w_c_name == 'PointWidget':
                w.place_widget()
            # Set state of rest of the attributes ignoring the widget_list.
            set_state(self, state, ignore=['widget_list'])
            # Some widgets need some cajoling to get their setup right.
            w.update_traits()
            if w_c_name == 'PlaneWidget':
                w.origin = state.widget.origin
                w.normal = state.widget.normal
                w.update_placement()
                w.get_poly_data(self.poly_data)
            elif w_c_name == 'SphereWidget':
                # XXX: This hack is necessary because the sphere widget
                # does not update its poly data even when its ivars are
                # set (plus it does not have an update_placement method
                # which is a bug).  So we force this by creating a similar
                # sphere source and copy its output.
                s = tvtk.SphereSource(center=w.center,
                                      radius=w.radius,
                                      theta_resolution=w.theta_resolution,
                                      phi_resolution=w.phi_resolution,
                                      lat_long_tessellation=True)
                s.update()
                self.poly_data.shallow_copy(s.output)
            else:
                w.get_poly_data(self.poly_data)
        finally:
            self._unpickling = False
        # Set the widgets trait so that the widget is rendered if needed.
        self.widgets = [w]

    ######################################################################
    # `Component` interface
    ######################################################################
    def setup_pipeline(self):
        """Override this method so that it *creates* the tvtk
        pipeline.

        This method is invoked when the object is initialized via
        `__init__`.  Note that at the time this method is called, the
        tvtk data pipeline will *not* yet be setup.  So upstream data
        will not be available.  The idea is that you simply create the
        basic objects and setup those parts of the pipeline not
        dependent on upstream sources and filters.  You should also
        set the `actors` attribute up at this point.
        """
        # Setup the widgets.
        sources = [
            tvtk.SphereWidget(theta_resolution=8, phi_resolution=6),
            tvtk.LineWidget(clamp_to_bounds=False),
            tvtk.PlaneWidget(),
            tvtk.PointWidget(outline=False,
                             x_shadows=False,
                             y_shadows=False,
                             z_shadows=False),
        ]
        self.widget_list = sources
        # The 'widgets' trait is set in the '_widget_changed' handler.
        self.widget = sources[0]

        for s in sources:
            self._connect(s)

    def update_pipeline(self):
        """Override this method so that it *updates* the tvtk pipeline
        when data upstream is known to have changed.

        This method is invoked (automatically) when any of the inputs
        sends a `pipeline_changed` event.
        """
        if len(self.inputs) == 0:
            return
        inp = self.inputs[0].outputs[0]
        w = self.widget
        w.input = inp

        if self._first:
            w.place_widget()
            self._first = False

        # If the dataset is effectively 2D switch to using the line
        # widget since that works best.
        b = inp.bounds
        extents = [(b[1] - b[0]), (b[3] - b[2]), (b[5] - b[4])]
        max_extent = max(extents)
        # A dataset with zero extent on every axis (e.g. a single point)
        # has no meaningful flat axis and would otherwise raise
        # ZeroDivisionError below.
        if max_extent > 0:
            for i, x in enumerate(extents):
                if x / max_extent < 1.0e-6:
                    w = self.widget = self.widget_list[1]
                    w.clamp_to_bounds = True
                    w.align = ['z_axis', 'z_axis', 'y_axis'][i]
                    break

        # Set our output.
        w.get_poly_data(self.poly_data)
        self.outputs = [self.poly_data]

        self.pipeline_changed = True

    def update_data(self):
        """Override this method so that it flushes the vtk pipeline if
        that is necessary.

        This method is invoked (automatically) when any of the inputs
        sends a `data_changed` event.
        """
        self.data_changed = True

    ######################################################################
    # `SourceWidget` interface
    ######################################################################
    def update_poly_data(self):
        """Copy the current widget's poly data into `poly_data`."""
        self.widget.get_poly_data(self.poly_data)

    ######################################################################
    # Non-public traits.
    ######################################################################
    def _widget_changed(self, value):
        # If we are being unpickled do nothing.
        if self._unpickling:
            return
        if value not in self.widget_list:
            # A new widget replaces the existing one of the same class,
            # or is appended if its class is not present yet.
            classes = [o.__class__ for o in self.widget_list]
            vc = value.__class__
            self._connect(value)
            if vc in classes:
                self.widget_list[classes.index(vc)] = value
            else:
                self.widget_list.append(value)

        recorder = self.recorder
        if recorder is not None:
            idx = self.widget_list.index(value)
            name = recorder.get_script_id(self)
            lhs = '%s.widget' % name
            rhs = '%s.widget_list[%d]' % (name, idx)
            recorder.record('%s = %s' % (lhs, rhs))

        if len(self.inputs) > 0:
            value.input = self.inputs[0].outputs[0]
            value.place_widget()

        value.on_trait_change(self.render)
        self.widgets = [value]

    def _update_mode_changed(self, value):
        if value in ['interactive', 'semi-interactive']:
            self.update_poly_data()
            self.render()

    def _on_interaction_event(self, obj, event):
        # VTK 'InteractionEvent' observer.
        if (not self._busy) and (self.update_mode == 'interactive'):
            self._busy = True
            try:
                self.update_poly_data()
            finally:
                self._busy = False

    def _on_widget_trait_changed(self):
        if (not self._busy) and (self.update_mode != 'non-interactive'):
            self._busy = True
            try:
                # This render call forces any changes to the trait to be
                # rendered only then will updating the poly data make
                # sense.
                self.render()
                self.update_poly_data()
            finally:
                self._busy = False

    def _on_alignment_set(self):
        w = self.widget
        w.place_widget()
        w.update_traits()

    def _connect(self, obj):
        """Wires up all the event handlers."""
        obj.add_observer('InteractionEvent', self._on_interaction_event)
        if isinstance(obj, tvtk.PlaneWidget):
            obj.on_trait_change(self._on_alignment_set, 'normal_to_x_axis')
            obj.on_trait_change(self._on_alignment_set, 'normal_to_y_axis')
            obj.on_trait_change(self._on_alignment_set, 'normal_to_z_axis')
        elif isinstance(obj, tvtk.LineWidget):
            obj.on_trait_change(self._on_alignment_set, 'align')

        # Setup the widgets colors.
        fg = (1, 1, 1)
        if self.scene is not None:
            fg = self.scene.foreground
        self._setup_widget_colors(obj, fg)

        obj.on_trait_change(self._on_widget_trait_changed)
        obj.on_trait_change(self.render)

    def _setup_widget_colors(self, widget, color):
        """Colour every '*property*' trait of `widget`.

        Selected-state properties are coloured red, all others `color`;
        both get a line width of 2.
        """
        trait_names = widget.trait_names()
        props = [
            x for x in trait_names if 'property' in x and 'selected' not in x
        ]
        sel_props = [
            x for x in trait_names if 'property' in x and 'selected' in x
        ]
        for p in props:
            setattr(getattr(widget, p), 'color', color)
            setattr(getattr(widget, p), 'line_width', 2)
        for p in sel_props:
            # Set the selected color to 'red'.
            setattr(getattr(widget, p), 'color', (1, 0, 0))
            setattr(getattr(widget, p), 'line_width', 2)
        self.render()

    def _foreground_changed_for_scene(self, old, new):
        # Change the default color for the actor.
        for w in self.widget_list:
            self._setup_widget_colors(w, new)
        self.render()

    def _scene_changed(self, old, new):
        super(SourceWidget, self)._scene_changed(old, new)
        # Guard against the scene being cleared, consistent with the
        # None check in `_connect`.
        if new is not None:
            self._foreground_changed_for_scene(None, new.foreground)