def __init__(self, klass=object, allow_none=True, factory=None, args=None,
             kw=None, **metadata):
    """Initialize a slot trait for objects of type/interface *klass*.

    When *klass* is a ``zope.interface.Interface`` no ``Instance`` trait is
    built; the construction arguments are stored on ``self`` instead and the
    default value is None.  Otherwise an ``Instance`` trait does the work and
    supplies the default value.  The slot is also flagged as a container when
    *klass* provides or implements ``IContainer``.
    """
    # issubclass() raises TypeError for arguments that are not class-like
    # (e.g. None), which simply means "not a zope interface".
    try:
        is_zope_iface = issubclass(klass, zope.interface.Interface)
    except TypeError:
        is_zope_iface = False

    metadata.setdefault('copy', 'deep')
    self._allow_none = allow_none
    self.klass = klass
    default_value = None

    self._is_container = bool(
        has_interface(klass, IContainer)
        or (isclass(klass) and IContainer.implementedBy(klass)))

    if not is_zope_iface:
        self._instance = Instance(klass=klass, allow_none=allow_none,
                                  factory=factory, args=args, kw=kw,
                                  **metadata)
        default_value = self._instance.default_value
    else:
        # No Instance trait for zope interfaces; presumably validation uses
        # klass.providedBy() when _instance is None (confirm in validate()).
        self._instance = None
        self.factory = factory
        self.args = args
        self.kw = kw

    super(Slot, self).__init__(default_value, **metadata)
class Slot(Variable):
    """Trait for an object of a given class, or one implementing a given
    interface.

    Both Traits interfaces and ``zope.interface.Interface`` subclasses are
    accepted as *klass*.
    """

    def __init__(self, klass=None, allow_none=True, factory=None, args=None,
                 kw=None, **metadata):
        # issubclass() raises TypeError when klass is not class-like
        # (e.g. the default None); treat that as "not a zope interface".
        try:
            zope_iface = issubclass(klass, zope.interface.Interface)
        except TypeError:
            zope_iface = False

        metadata.setdefault('copy', 'deep')
        self._allow_none = allow_none
        self.klass = klass

        if zope_iface:
            # zope interfaces are checked with providedBy() in validate(),
            # so keep the construction arguments instead of an Instance.
            self._instance = None
            self.factory = factory
            self.args = args
            self.kw = kw
            default_value = None
        else:
            self._instance = Instance(klass=klass, allow_none=allow_none,
                                      factory=factory, args=args, kw=kw,
                                      **metadata)
            default_value = self._instance.default_value

        super(Slot, self).__init__(default_value, **metadata)

    def validate(self, obj, name, value):
        """Return *value* if it is acceptable for this slot.

        None is accepted only when ``allow_none`` was set.  Otherwise the value
        must provide the zope interface, or pass the wrapped ``Instance``
        trait's validation.  Failures are reported through
        ``obj.raise_exception`` with a TypeError.
        """
        if value is None:
            if self._allow_none:
                return value
            self.validate_failed(obj, name, value)

        if self._instance is None:
            # klass is a zope interface.
            if self.klass.providedBy(value):
                return value
            self._iface_error(obj, name, self.klass.__name__)
            return None

        try:
            return self._instance.validate(obj, name, value)
        except Exception:
            # Replace the Instance trait's error with a clearer message.
            wanted = self._instance.klass
            if issubclass(wanted, Interface):
                self._iface_error(obj, name, wanted.__name__)
            else:
                obj.raise_exception(
                    "%s must be an instance of class '%s'"
                    % (name, wanted.__name__),
                    TypeError)

    def _iface_error(self, obj, name, iface_name):
        """Report that *name* does not provide interface *iface_name*."""
        message = "%s must provide interface '%s'" % (name, iface_name)
        obj.raise_exception(message, TypeError)
def __init__(self, klass=None, allow_none=True, factory=None, args=None,
             kw=None, **metadata):
    """Initialize a slot trait for objects of type/interface *klass*.

    For a ``zope.interface.Interface`` the construction arguments are stored
    on ``self`` and the default value is None.  Otherwise an ``Instance``
    trait is built and supplies the default value.
    """
    # A non-class argument (e.g. the default None) makes issubclass() raise
    # TypeError; that just means klass is not a zope interface.
    try:
        zope_iface = issubclass(klass, zope.interface.Interface)
    except TypeError:
        zope_iface = False

    metadata.setdefault('copy', 'deep')
    self._allow_none = allow_none
    self.klass = klass

    if zope_iface:
        self._instance = None
        self.factory = factory
        self.args = args
        self.kw = kw
        default_value = None
    else:
        self._instance = Instance(klass=klass, allow_none=allow_none,
                                  factory=factory, args=args, kw=kw,
                                  **metadata)
        default_value = self._instance.default_value

    super(Slot, self).__init__(default_value, **metadata)
class Collection(Filter):
    """
    Defines a Collection filter which is a collection of mayavi
    filters/components bundled into one.

    The wrapped filters are chained in order: the first one receives this
    collection's input, every later one takes the previous filter as its
    input, and the collection exposes the last filter's outputs.
    """

    # The filters we manage.
    filters = List(Instance(PipelineBase), record=True)

    ########################################
    # Private traits.

    # Whether the internal filter chain has been wired. Used internally.
    _pipeline_ready = Bool(False)

    ######################################################################
    # `object` interface.
    ######################################################################
    def __set_pure_state__(self, state):
        # Recreate the child filters from the saved state first, then let
        # the base class restore everything else.
        handle_children_state(self.filters, state.filters)
        super(Collection, self).__set_pure_state__(state)

    ######################################################################
    # HasTraits interface.
    ######################################################################
    def default_traits_view(self):
        """Returns the default traits view for this object."""
        list_editor = ListEditor(use_notebook=True,
                                 deletable=False,
                                 export='DockWindowShell',
                                 page_name='.name')
        filters_item = Item(name='filters',
                            style='custom',
                            show_label=False,
                            editor=list_editor,
                            resizable=True)
        return View(Group(filters_item, show_labels=False), resizable=True)

    ######################################################################
    # `Filter` interface.
    ######################################################################
    def setup_pipeline(self):
        """Setup the pipeline."""
        # A subclass may supply the filters as the trait's default value;
        # _filters_changed never fires in that case, so wire them up here.
        if len(self.filters) > 0 and not self._pipeline_ready:
            self._filters_changed([], self.filters)

    def stop(self):
        # start() needs no override: _handle_filters_changed calls
        # _setup_pipeline, which starts every wrapped filter.
        super(Collection, self).stop()
        for child in self.filters:
            child.stop()

    def update_pipeline(self):
        """This method *updates* the tvtk pipeline when data upstream is
        known to have changed.

        This method is invoked (automatically) when the input fires a
        `pipeline_changed` event.
        """
        self._setup_pipeline()
        # Let downstream objects know as well.
        self.pipeline_changed = True

    def update_data(self):
        """This method does what is necessary when upstream data changes.

        This method is invoked (automatically) when any of the inputs
        sends a `data_changed` event.
        """
        # Simply forward the data_changed event.
        self.data_changed = True

    ######################################################################
    # Private interface.
    ######################################################################
    def _setup_pipeline(self):
        """Sets up the objects in the pipeline."""
        if not self.inputs or not self.filters:
            return

        chain = self.filters
        if not self._pipeline_ready:
            # Feed our input into the first filter, then link each filter
            # to the one before it.
            chain[0].inputs = [self.inputs[0]]
            for upstream, downstream in zip(chain[:-1], chain[1:]):
                downstream.inputs = [upstream]
            self._pipeline_ready = True

        for child in chain:
            child.start()

        # Our outputs are those of the final filter.
        self._set_outputs(chain[-1].outputs)

    def _filters_changed(self, old, new):
        """Static traits handler."""
        self._handle_filters_changed(old, new)

    def _filters_items_changed(self, list_event):
        """Static traits handler."""
        self._handle_filters_changed(list_event.removed, list_event.added)

    def _scene_changed(self, old, new):
        """Static traits handler."""
        for child in self.filters:
            child.scene = new
        super(Collection, self)._scene_changed(old, new)

    def _handle_filters_changed(self, removed, added):
        """Unhook removed filters, prepare added ones and rewire the chain."""
        for child in removed:
            self._setup_events(child, remove=True)
            child.stop()

        for child in added:
            if self.scene is not None:
                child.scene = self.scene
            if len(child.name) == 0:
                child.name = child.__class__.__name__
            # Only the last filter's events are listened to.
            # NOTE(review): a filter that was last before another one was
            # appended keeps its hooks, and a filter that becomes last
            # because the old last was removed gets none; confirm intended.
            if child is self.filters[-1]:
                self._setup_events(child)

        self._pipeline_ready = False
        self._setup_pipeline()

    def _fire_pipeline_changed(self):
        # When the last filter fires a pipeline changed we only need to
        # reset our outputs to its outputs; calling _setup_pipeline here is
        # expensive and would cause a recursion error.
        self._set_outputs(self.filters[-1].outputs)

    def _setup_events(self, obj, remove=False):
        """(Un)register our handlers for *obj*'s change events."""
        handlers = ((self.update_data, 'data_changed'),
                    (self._fire_pipeline_changed, 'pipeline_changed'))
        for handler, event_name in handlers:
            obj.on_trait_change(handler, event_name, remove=remove)

    def _visible_changed(self, value):
        for child in self.filters:
            child.visible = value
        super(Collection, self)._visible_changed(value)

    def _recorder_changed(self, old, new):
        super(Collection, self)._recorder_changed(old, new)
        for child in self.filters:
            child.recorder = new
class AbstractAdapterFactory(HasTraits):
    """ Abstract base class for all adapter factories.

    Adapter factories define behavioural extensions for classes.

    Subclasses implement ``_can_adapt`` and ``_adapt``; ``adapt`` combines
    them and logs a warning when a factory that claimed it could adapt an
    object then returned None.
    """

    #### 'AbstractAdapterFactory' interface ###################################

    # The adapter manager that the factory is registered with (this will be
    # None iff the factory is not registered with a manager).
    adapter_manager = Instance(AdapterManager)

    # The type system used by the factory (it determines 'is_a' relationships
    # and type MROs etc). By default we use standard Python semantics.
    type_system = Delegate('adapter_manager')

    ###########################################################################
    # 'AbstractAdapterFactory' interface.
    ###########################################################################

    def adapt(self, adaptee, target_class, *args, **kw):
        """ Returns an adapter that adapts an object to the target class.

        Returns None if the factory cannot produce such an adapter.  Extra
        positional/keyword arguments are passed through unchanged to
        ``_can_adapt`` and ``_adapt``.
        """
        if not self._can_adapt(adaptee, target_class, *args, **kw):
            return None

        adapter = self._adapt(adaptee, target_class, *args, **kw)
        if adapter is None:
            # The factory claimed it could adapt but did not.
            # (``Logger.warn`` is a deprecated alias of ``warning``.)
            logger.warning(self._get_warning_message(adaptee, target_class))

        return adapter

    ###########################################################################
    # Protected 'AbstractAdapterFactory' interface.
    ###########################################################################

    def _can_adapt(self, adaptee, target_class, *args, **kw):
        """ Returns True if the factory can produce an appropriate adapter.

        Must be implemented by subclasses.
        """
        raise NotImplementedError

    def _adapt(self, adaptee, target_class, *args, **kw):
        """ Returns an adapter that adapts an object to the target class.

        Must be implemented by subclasses.
        """
        raise NotImplementedError

    ###########################################################################
    # Private interface.
    ###########################################################################

    def _get_warning_message(self, adaptee, target_class):
        """ Returns a warning message.

        The warning message is used when a factory fails to adapt something
        that it said it could!
        """
        message = '%s failed to adapt %s to %s' % (
            self.__class__.__name__, str(adaptee), target_class.__name__)

        return message
class StarComponent(Component):
    """Interactive canvas on which star shapes can be drawn and edited.

    A left click and drag draws a new star.  The mouse wheel changes the
    radius ratio of the star under the cursor.  A right click and drag moves
    a star.
    """
    stars = List(Star)
    star_color = Color((255, 255, 255))
    edges = Range(3, 10, 5)
    sx = Float  # star centre X when a move starts
    sy = Float  # star centre Y when a move starts
    mx = Float  # mouse X when a move starts
    my = Float  # mouse Y when a move starts
    moving_star = Instance(Star)

    event_state = Enum("normal", "drawing", "moving")

    def normal_left_down(self, event):
        "Append a new Star at the click position and enter the drawing state."
        new_star = Star(x=event.x, y=event.y, r=0, theta=0, n=self.edges,
                        s=0.5, c=convert_color(self.star_color))
        self.stars.append(new_star)
        self.event_state = "drawing"
        self.request_redraw()

    def drawing_mouse_move(self, event):
        "Update the radius and start angle of the last (in-progress) star."
        current = self.stars[-1]
        dx = event.x - current.x
        dy = event.y - current.y
        current.r = np.sqrt(dx ** 2 + dy ** 2)
        current.theta = np.arctan2(dy, dx)
        self.request_redraw()

    def drawing_left_up(self, event):
        "Finish drawing the star and return to the normal state."
        self.event_state = "normal"

    def normal_mouse_wheel(self, event):
        "Change the radius ratio of the star under the mouse (min 0.05)."
        hit = self.find_star(event.x, event.y)
        if hit is None:
            return
        hit.s += event.mouse_wheel * 0.02
        if hit.s < 0.05:
            hit.s = 0.05
        self.request_redraw()

    def normal_right_down(self, event):
        "Pick the star under the mouse, remember it and enter the moving state."
        hit = self.find_star(event.x, event.y)
        if hit is None:
            return
        # Remember where the mouse and the star centre were at the start.
        self.mx, self.my = event.x, event.y
        self.sx, self.sy = hit.x, hit.y
        self.moving_star = hit
        self.event_state = "moving"

    def moving_mouse_move(self, event):
        "Move moving_star by the mouse offset since the move started."
        target = self.moving_star
        target.x = self.sx + event.x - self.mx
        target.y = self.sy + event.y - self.my
        self.request_redraw()

    def moving_right_up(self, event):
        "End the move and return to the normal state."
        self.event_state = "normal"

    def _draw_overlay(self, gc, view_bounds=None, mode="normal"):
        gc.clear((0, 0, 0, 1))  # fill everything with black
        gc.save_state()
        for item in self.stars:
            draw_star(gc, item.x, item.y, item.r, item.c, item.theta,
                      item.n, item.s)
            gc.draw_path()
        gc.restore_state()

    def find_star(self, x, y):
        """Return the topmost (most recently added) star containing (x, y),
        or None."""
        from enthought.kiva.agg import points_in_polygon
        for candidate in reversed(self.stars):
            if points_in_polygon((x, y), candidate.polygon()):
                return candidate
        return None
class CustomSaveTool(SaveTool):
    """ This tool allows the user to press Ctrl+S to save a snapshot image of
    the plot component.

    The target file name is asked for in a modal dialog.  A ``.pdf``
    extension (in any letter case) produces a PDF; anything else produces a
    raster image.
    """

    # This hack was done because calling self.configure_traits in
    # normal_key_pressed causes corruption in saved image and also the ui;
    # the file name lives on its own HasTraits object with its own dialog.
    class FileName(HasTraits):
        filename = File("saved_plot.png")

    # Holder for the file name.  Passing ``()`` makes Traits create a fresh
    # FileName for each tool.  The previous ``FileName()`` default was a single
    # object shared by every CustomSaveTool.
    filenameview = Instance(FileName, ())

    # The file that the image is saved in. The format will be deduced from
    # the extension.
    filename = DelegatesTo('filenameview')

    #-------------------------------------------------------------------------
    # PDF format options
    # This mirror the traits in PdfPlotGraphicsContext.
    #-------------------------------------------------------------------------

    pagesize = Enum("letter", "A4")
    dest_box = Tuple((0.5, 0.5, -0.5, -0.5))
    dest_box_units = Enum("inch", "cm", "mm", "pica")

    #-------------------------------------------------------------------------
    # Override default trait values inherited from BaseTool
    #-------------------------------------------------------------------------

    # This tool does not have a visual representation (overrides BaseTool).
    draw_mode = "none"

    # This tool is not visible (overrides BaseTool).
    visible = False

    def normal_key_pressed(self, event):
        """ Handles a key-press when the tool is in the 'normal' state.

        On Control+S, asks for a file name in a modal dialog and, if the user
        confirms, saves the plot as PDF or raster image and marks the event
        handled.
        """
        if self.component is None:
            return
        if not (event.character == "s" and event.control_down):
            return

        dialog = View(
            Item('filename',
                 editor=FileEditor(entries=0, filter=[
                     'PNG file (*.png)|*.png',
                     'GIF file (*.gif)|*.gif',
                     'JPG file (*.jpg)|*.jpg',
                     'JPEG file (*.jpeg)|*.jpeg',
                     'PDF file (*.pdf)|*.pdf',
                 ])),
            buttons=['OK', 'Cancel'])
        if self.filenameview.configure_traits(view=dialog, kind='modal'):
            # Compare case-insensitively so "plot.PDF" is saved as a PDF too.
            extension = os.path.splitext(self.filename)[-1].lower()
            if extension == ".pdf":
                self._save_pdf()
            else:
                self._save_raster()
            event.handled = True

    def _save_raster(self):
        """ Saves an image of the component.
        """
        from enthought.chaco.api import PlotGraphicsContext
        gc = PlotGraphicsContext((int(self.component.outer_width),
                                  int(self.component.outer_height)))
        self.component.draw(gc, mode="normal")
        gc.save(self.filename)

    def _save_pdf(self):
        """ Saves the component as a PDF using the page options above.
        """
        from enthought.chaco.pdf_graphics_context import PdfPlotGraphicsContext
        gc = PdfPlotGraphicsContext(filename=self.filename,
                                    pagesize=self.pagesize,
                                    dest_box=self.dest_box,
                                    dest_box_units=self.dest_box_units)
        gc.render_component(self.component, container_coords=True)
        gc.save()
class ViewportDefiner(HasTraits):
    """Interactively edit the polygonal viewport of one virtual display.

    The viewport polygon is drawn with a LineSegmentTool on top of an image
    plot.  The rasterised polygon is pushed to ``display_server``.  Once the
    polygon has at least three vertices, it is also written back to the ROS
    parameter ``<display_name>/display/virtualDisplays``.
    """
    width = traits.Int
    height = traits.Int
    display_name = traits.String
    plot = Instance(Component)
    linedraw = Instance(LineSegmentTool)
    viewport_id = traits.String
    display_mode = traits.Trait('white on black', 'black on white')
    display_server = traits.Any
    display_info = traits.Any
    show_grid = traits.Bool

    traits_view = View(
        Group(Item('display_mode'),
              Item('display_name'),
              Item('viewport_id'),
              Item('plot', editor=ComponentEditor(), show_label=False),
              orientation="vertical"),
        resizable=True,
    )

    def __init__(self, *args, **kwargs):
        """Locate ``viewport_id`` in ``display_info`` and load its polygon.

        Raises ValueError (an Exception subclass, as before) when the id is
        not among ``display_info['virtualDisplays']``.
        """
        super(ViewportDefiner, self).__init__(*args, **kwargs)

        # Find our index in the viewport list (last match wins).
        virtual_displays = self.display_info['virtualDisplays']
        viewport_ids = [vd['id'] for vd in virtual_displays]
        self.viewport_idx = -1
        for i, vd_id in enumerate(viewport_ids):
            if vd_id == self.viewport_id:
                self.viewport_idx = i
        if self.viewport_idx == -1:
            raise ValueError("Could not find viewport (available ids: %s)" %
                             ",".join(viewport_ids))

        # Set these before the first _update_image(): it calls
        # update_ROS_params() as soon as the line tool holds >= 3 points.
        self.fqdn = self.display_name + '/display/virtualDisplays'
        self.this_virtual_display = virtual_displays[self.viewport_idx]

        self._update_image()

        # Only accept the stored polygon if every vertex lies in the image.
        viewport = self.this_virtual_display['viewport']
        if all(x < self.width and y < self.height for (x, y) in viewport):
            self.linedraw.points = viewport
        else:
            self.linedraw.points = []
            rospy.logwarn('invalid points')
        self._update_image()

    def _update_image(self):
        """Rebuild ``_image`` from the polygon, then refresh the plot, the
        display server and (for >= 3 points) the ROS parameters."""
        self._image = np.zeros((self.height, self.width, 3), dtype=np.uint8)
        fill_polygon.fill_polygon(self.linedraw.points, self._image)

        if self.show_grid:
            # draw red horizontal stripes
            for row in range(0, self.height, 100):
                self._image[row:row + 10, :, 0] = 255
            # draw blue vertical stripes
            for col in range(0, self.width, 100):
                self._image[:, col:col + 10, 2] = 255

        # _pd only exists once the plot has been built.
        if hasattr(self, '_pd'):
            self._pd.set_data("imagedata", self._image)
        self.send_array()
        if len(self.linedraw.points) >= 3:
            self.update_ROS_params()

    def _plot_default(self):
        """Build the image plot (with pan and zoom tools) showing ``_image``.

        Relies on ``_image`` already existing; ``_update_image`` sets it
        before touching ``linedraw`` (whose default builds this plot).
        """
        self._pd = ArrayPlotData()
        self._pd.set_data("imagedata", self._image)

        plot = Plot(self._pd, default_origin="top left")
        plot.x_axis.orientation = "top"
        plot.img_plot("imagedata")
        plot.bgcolor = "white"

        # Tweak some of the plot properties
        plot.title = "Click to add points, press Enter to clear selection"
        plot.padding = 50
        plot.line_width = 1

        # Attach some tools to the plot
        pan = PanTool(plot, drag_button="right", constrain_key="shift")
        plot.tools.append(pan)
        zoom = ZoomTool(component=plot, tool_mode="box", always_on=False)
        plot.overlays.append(zoom)
        return plot

    def _linedraw_default(self):
        """Create the polygon-drawing overlay and listen to its points."""
        linedraw = LineSegmentTool(self.plot, color=(0.5, 0.5, 0.9, 1.0))
        self.plot.overlays.append(linedraw)
        linedraw.on_trait_change(self.points_changed, 'points[]')
        return linedraw

    def points_changed(self):
        self._update_image()

    @traits.on_trait_change('display_mode')
    def send_array(self):
        """Push the current image to the display server."""
        # display_mode may change (e.g. via constructor kwargs) before the
        # first image has been built; there is nothing to send yet.
        if not hasattr(self, '_image'):
            return
        # NOTE(review): foreground/background colours used to be derived
        # from display_mode here but were never used; the image is sent
        # as-is, so display_mode currently has no visual effect.
        self.display_server.show_pixels(self._image)

    def get_viewport_verts(self):
        """Return the polygon vertices clamped to the image as [x, y] lists.

        Lists (not tuples) are used for maximal json compatibility.
        """
        return [[fill_polygon.posint(x, self.width - 1),
                 fill_polygon.posint(y, self.height - 1)]
                for (x, y) in self.linedraw.points]

    def update_ROS_params(self):
        """Store our viewport polygon back into the ROS parameter server."""
        viewport_verts = self.get_viewport_verts()
        self.this_virtual_display['viewport'] = viewport_verts
        self.display_info['virtualDisplays'][
            self.viewport_idx] = self.this_virtual_display
        rospy.set_param(self.fqdn, self.display_info['virtualDisplays'])
class CViewerPreferenceManager(PreferenceManager):
    """Preference manager that adds the ConnectomeViewer UI preferences
    page alongside the 'root' and 'mlab' pages."""

    # add the cviewer ui preferences
    cviewerui = Instance(PreferencesHelper)

    # The preferences.
    preferences = Instance(IPreferences)

    ######################################################################
    # Traits UI view.

    traits_view = View(
        Group(
            Group(Item(name='root', style='custom'),
                  show_labels=False, label='Root', show_border=True),
            Group(Item(name='mlab', style='custom'),
                  show_labels=False, label='Mlab', show_border=True),
            Group(Item(name='cviewerui', style='custom'),
                  show_labels=False, label='ConnectomeViewer',
                  show_border=True),
        ),
        buttons=['OK', 'Cancel'],
        resizable=True,
    )

    ######################################################################
    # `HasTraits` interface.
    ######################################################################
    def __init__(self, **traits):
        # NOTE(review): super(PreferenceManager, ...) bypasses
        # PreferenceManager.__init__ itself and loads the defaults here
        # instead; presumably intentional to avoid loading twice — confirm.
        super(PreferenceManager, self).__init__(**traits)

        if 'preferences' not in traits:
            self._load_preferences()

    def _preferences_default(self):
        return ScopedPreferences()

    def _cviewerui_default(self):
        """Trait initializer."""
        return CViewerUIPreferencesHelper(preferences=self.preferences)

    def _load_preferences(self):
        """Load the default preferences.

        Each package's bundled 'preferences.ini' is loaded into the
        'default/' node while ETSConfig.application_home temporarily points
        at the cviewer data directory.  The original application_home is
        always restored.
        """
        # Save current application_home.
        app_home = ETSConfig.get_application_home()
        logger.debug('Application home: %s', app_home)
        # Set it to where the cviewer preferences are temporarily.
        ETSConfig.application_home = join(ETSConfig.get_application_data(), ID)
        try:
            for pkg in ('cviewer.plugins.ui',
                        'enthought.mayavi.preferences',
                        'enthought.tvtk.plugins.scene'):
                pref_file = pkg_resources.resource_stream(pkg,
                                                          'preferences.ini')
                # Close the stream even if loading it fails.
                try:
                    default = self.preferences.node('default/')
                    default.load(pref_file)
                finally:
                    pref_file.close()
        finally:
            # Set back the application home.
            ETSConfig.application_home = app_home

    def _preferences_changed(self, preferences):
        """Setup the helpers if the preferences trait changes."""
        # cviewerui is built from self.preferences, so keep it in sync too.
        for helper in (self.root, self.cviewerui):
            helper.preferences = preferences
class SetActiveAttribute(Filter):
    """
    This filter lets a user set the active data attribute (scalars,
    vectors and tensors) on a VTK dataset.  This is particularly useful
    if you need to do something like compute contours of one scalar on
    the contour of another scalar.
    """

    # Note: most of this code is from the XMLFileDataReader.

    # The version of this class.  Used for persistence.
    __version__ = 0

    input_info = PipelineInfo(datasets=['any'],
                              attribute_types=['any'],
                              attributes=['any'])

    output_info = PipelineInfo(datasets=['any'],
                               attribute_types=['any'],
                               attributes=['any'])

    ########################################
    # Dynamic traits: These traits are dynamic and are automatically
    # updated depending on the attributes available in the input data.

    # The active point scalar name.  An empty string indicates that
    # the attribute is "deactivated".  This is useful when you have
    # both point and cell attributes and want to use cell data by
    # default.
    point_scalars_name = DEnum(values_name='_point_scalars_list',
                               desc='scalar point data attribute to use')
    # The active point vector name.
    point_vectors_name = DEnum(values_name='_point_vectors_list',
                               desc='vectors point data attribute to use')
    # The active point tensor name.
    point_tensors_name = DEnum(values_name='_point_tensors_list',
                               desc='tensor point data attribute to use')

    # The active cell scalar name.
    cell_scalars_name = DEnum(values_name='_cell_scalars_list',
                              desc='scalar cell data attribute to use')
    # The active cell vector name.
    cell_vectors_name = DEnum(values_name='_cell_vectors_list',
                              desc='vectors cell data attribute to use')
    # The active cell tensor name.
    cell_tensors_name = DEnum(values_name='_cell_tensors_list',
                              desc='tensor cell data attribute to use')

    ########################################
    # Our view.

    view = View(
        Group(
            Item(name='point_scalars_name'),
            Item(name='point_vectors_name'),
            Item(name='point_tensors_name'),
            Item(name='cell_scalars_name'),
            Item(name='cell_vectors_name'),
            Item(name='cell_tensors_name'),
        ))

    ########################################
    # Private traits.

    # These private traits store the list of available data
    # attributes.  The non-private traits use these lists internally.
    _point_scalars_list = List(Str)
    _point_vectors_list = List(Str)
    _point_tensors_list = List(Str)
    _cell_scalars_list = List(Str)
    _cell_vectors_list = List(Str)
    _cell_tensors_list = List(Str)

    # This filter allows us to change the attributes of the data
    # object and will ensure that the pipeline is properly taken care
    # of.  Directly setting the array in the VTK object will not do
    # this.
    _assign_attribute = Instance(tvtk.AssignAttribute, args=(),
                                 allow_none=False)

    # Toggles if this is the first time this object has been used.
    _first = Bool(True)

    ######################################################################
    # `object` interface
    ######################################################################
    def __get_pure_state__(self):
        d = super(SetActiveAttribute, self).__get_pure_state__()
        for name in ('_assign_attribute', '_first'):
            d.pop(name, None)
        # Pickle the 'point_scalars_name' etc. since these are
        # properties and not in __dict__.
        attr = {}
        for name in ('point_scalars', 'point_vectors',
                     'point_tensors', 'cell_scalars',
                     'cell_vectors', 'cell_tensors'):
            d.pop('_' + name + '_list', None)
            d.pop('_' + name + '_name', None)
            x = name + '_name'
            attr[x] = getattr(self, x)
        d.update(attr)
        return d

    ######################################################################
    # `Filter` interface.
    ######################################################################
    def update_data(self):
        self.data_changed = True

    def update_pipeline(self):
        if len(self.inputs) == 0 or len(self.inputs[0].outputs) == 0:
            return

        aa = self._assign_attribute
        aa.input = self.inputs[0].outputs[0]
        self._update()
        self._set_outputs([aa.output])

    ######################################################################
    # Non-public interface.
    ######################################################################
    def _update(self):
        """Updates the traits for the fields that are available in the
        input data.
        """
        if len(self.inputs) == 0 or len(self.inputs[0].outputs) == 0:
            return

        dataset = self.inputs[0].outputs[0]
        if self._first:
            # Force all attributes to be defined and computed
            dataset.update()
        pnt_attr, cell_attr = get_all_attributes(dataset)

        self._setup_data_traits(cell_attr, 'cell')
        self._setup_data_traits(pnt_attr, 'point')
        self._first = False

    def _setup_data_traits(self, attributes, d_type):
        """Given the dict of the attributes from the
        `get_all_attributes` function and the data type (point/cell)
        data this will setup the object and the data.

        An empty-string choice is appended to every list so the attribute
        can be deactivated.  On first use an empty selection defaults to the
        first available array.
        """
        aa = self._assign_attribute
        dataset = self.inputs[0].outputs[0]
        data = getattr(dataset, '%s_data' % d_type)
        for attr in ('scalars', 'vectors', 'tensors'):
            values = attributes[attr]
            values.append('')
            setattr(self, '_%s_%s_list' % (d_type, attr), values)
            if len(values) > 1:
                default = getattr(self, '%s_%s_name' % (d_type, attr))
                if self._first and len(default) == 0:
                    default = values[0]
                getattr(data, 'set_active_%s' % attr)(default)
                aa.assign(default, attr.upper(),
                          d_type.upper() + '_DATA')
                aa.update()
                # Update the trait silently: the data is already set above.
                kw = {'%s_%s_name' % (d_type, attr): default,
                      'trait_change_notify': False}
                self.set(**kw)

    def _set_data_name(self, data_type, attr_type, value):
        """Make *value* the active *data_type* ('scalars', 'vectors' or
        'tensors') array of the input's *attr_type* ('point'/'cell') data.

        An empty *value* deactivates that attribute instead.  Nothing happens
        while there is no input dataset yet.
        """
        if value is None or len(self.inputs) == 0 \
                or len(self.inputs[0].outputs) == 0:
            return

        dataset = self.inputs[0].outputs[0]
        data = getattr(dataset, attr_type + '_data')
        method = getattr(data, 'set_active_%s' % data_type)
        if len(value) == 0:
            # If the value is empty then we deactivate that attribute.
            method(None)
        else:
            method(value)
            aa = self._assign_attribute
            aa.assign(value, data_type.upper(), attr_type.upper() + '_DATA')
            aa.update()
        # Fire an event, so the changes propagate.
        self.data_changed = True

    def _point_scalars_name_changed(self, value):
        self._set_data_name('scalars', 'point', value)

    def _point_vectors_name_changed(self, value):
        self._set_data_name('vectors', 'point', value)

    def _point_tensors_name_changed(self, value):
        self._set_data_name('tensors', 'point', value)

    def _cell_scalars_name_changed(self, value):
        self._set_data_name('scalars', 'cell', value)

    def _cell_vectors_name_changed(self, value):
        self._set_data_name('vectors', 'cell', value)

    def _cell_tensors_name_changed(self, value):
        self._set_data_name('tensors', 'cell', value)
class ExampleWorkbenchWindow(WorkbenchWindow):
    """A simple example of using the workbench window.

    It adds File/View/User menus, a toolbar with an Exit action, and a
    'New Person' action secured by the NewPersonPerm permission.
    """

    #### 'WorkbenchWindow' interface ##########################################

    # The available perspectives.
    perspectives = [
        Perspective(
            name='Foo',
            contents=[
                PerspectiveItem(id='Black', position='bottom'),
                PerspectiveItem(id='Debug', position='left'),
            ],
        ),
        Perspective(
            name='Bar',
            contents=[
                PerspectiveItem(id='Debug', position='left'),
            ],
        ),
    ]

    #### Private interface ####################################################

    # The Exit action.
    _exit_action = Instance(Action)

    # The New Person action.
    _new_person_action = Instance(Action)

    ###########################################################################
    # 'ApplicationWindow' interface.
    ###########################################################################

    def _editor_manager_default(self):
        """ Trait initializer.

        Here we return the replacement editor manager.
        """
        return ExampleEditorManager()

    def _menu_bar_manager_default(self):
        """Trait initializer."""
        menus = [
            MenuManager(self._new_person_action, self._exit_action,
                        name='&File', id='FileMenu'),
            ViewMenuManager(name='&View', id='ViewMenu', window=self),
            UserMenuManager(id='UserMenu', window=self),
        ]
        return MenuBarManager(*menus, window=self)

    def _tool_bar_manager_default(self):
        """Trait initializer."""
        return ToolBarManager(self._exit_action, show_tool_names=False)

    ###########################################################################
    # 'WorkbenchWindow' interface.
    ###########################################################################

    def _views_default(self):
        """Trait initializer."""
        from secured_debug_view import SecuredDebugView

        return [SecuredDebugView(window=self)]

    ###########################################################################
    # Private interface.
    ###########################################################################

    def __exit_action_default(self):
        """Trait initializer."""
        return Action(name='E&xit', on_perform=self.workbench.exit)

    def __new_person_action_default(self):
        """Trait initializer."""
        # Wrap the action in a SecureProxy guarded by NewPersonPerm.
        plain_action = Action(name='New Person', on_perform=self._new_person)
        return SecureProxy(plain_action, permissions=[NewPersonPerm])

    def _new_person(self):
        """Create a new person."""
        person = Person(name='New', age=100)
        self.workbench.edit(person)
class ImageReader(FileDataSource):
    """A Image file reader. The reader supports all the
    different types of Image files.

    The concrete tvtk reader is chosen from the file extension (case
    insensitive); unknown extensions fall back to ``tvtk.ImageReader``.
    """

    # The version of this class.  Used for persistence.
    __version__ = 0

    # The Image data file reader.
    reader = Instance(tvtk.Object, allow_none=False, record=True)

    # Information about what this object can produce.
    output_info = PipelineInfo(datasets=['image_data'])

    # Our view.
    view = View(Group(Include('time_step_group'),
                      Item(name='base_file_name'),
                      Item(name='reader',
                           style='custom',
                           resizable=True),
                      show_labels=False),
                resizable=True)

    ######################################################################
    # Private Traits
    _image_reader_dict = Dict(Str, Instance(tvtk.Object))

    ######################################################################
    # `object` interface
    ######################################################################
    def __init__(self, **traits):
        # Map lower-case file extensions to the reader that handles them.
        d = {'bmp': tvtk.BMPReader(),
             'jpg': tvtk.JPEGReader(),
             'png': tvtk.PNGReader(),
             'pnm': tvtk.PNMReader(),
             'dcm': tvtk.DICOMImageReader(),
             'tiff': tvtk.TIFFReader(),
             'ximg': tvtk.GESignaReader(),
             'dem': tvtk.DEMReader(),
             'mha': tvtk.MetaImageReader(),
             'mhd': tvtk.MetaImageReader(),
             }
        # Account for pre 5.2 VTk versions, without MINC reader
        if hasattr(tvtk, 'MINCImageReader'):
            d['mnc'] = tvtk.MINCImageReader()
        # Common alternative spellings share the same reader object.
        d['jpeg'] = d['jpg']
        d['tif'] = d['tiff']
        self._image_reader_dict = d
        # Call parent class' init.
        super(ImageReader, self).__init__(**traits)

    def __set_pure_state__(self, state):
        # The reader has its own file_name which needs to be fixed.
        state.reader.file_name = state.file_path.abs_pth
        # Now call the parent class to setup everything.
        super(ImageReader, self).__set_pure_state__(state)

    ######################################################################
    # `FileDataSource` interface
    ######################################################################
    def update(self):
        self.reader.update()
        if len(self.file_path.get()) == 0:
            return
        self.render()

    ######################################################################
    # Non-public interface
    ######################################################################
    def _file_path_changed(self, fpath):
        """Pick a reader for the new file, load it and expose its output."""
        value = fpath.get()
        if len(value) == 0:
            return

        file_name = value.strip()
        # Extract the file extension
        extension = file_name.split('.')[-1].lower()

        # Select image reader based on file type.  (``dict.has_key`` no
        # longer exists in Python 3; ``in`` works on both.)
        old_reader = self.reader
        if extension in self._image_reader_dict:
            self.reader = self._image_reader_dict[extension]
        else:
            self.reader = tvtk.ImageReader()

        self.reader.file_name = file_name
        self.reader.update()
        self.reader.update_information()

        # Move the re-render hook from the previous reader to the new one.
        if old_reader is not None:
            old_reader.on_trait_change(self.render, remove=True)
        self.reader.on_trait_change(self.render)

        self.outputs = [self.reader.output]

        # Change our name on the tree view
        self.name = self._get_name()

    def _get_name(self):
        """ Returns the name to display on the tree view.  Note that
        this is not a property getter.
        """
        fname = basename(self.file_path.get())
        ret = "%s" % fname
        if len(self.file_list) > 1:
            ret += " (timeseries)"
        if '[Hidden]' in self.name:
            ret += ' [Hidden]'

        return ret
class Circle(HasTraits):
    """A circle positioned by a ``Point`` stored in `center`.

    `x` and `y` are both taken from `center`; `r` is an integer trait
    (presumably the radius).
    """

    center = Instance(Point)

    # Reads and writes of `x` are forwarded to ``center.x``.
    x = DelegatesTo('center')

    # `y` follows ``center.y`` until it is assigned on the circle itself.
    y = PrototypedFrom('center')

    r = Int
class MyModel(HasTraits):
    # Interactive Mayavi viewer that steps through the entries of `data`
    # and draws either iso-surfaces (scalar fields) or quiver glyphs
    # (vector fields), optionally with atom positions from `geo_spec`.
    #
    # NOTE(review): this class reads names it does not define (data,
    # iso_min, iso_max, iso_val, datalabels, is_vectorfield, geo_spec,
    # X, Y, Z); presumably it is nested in a function that provides them
    # -- confirm against the enclosing scope.

    # Index of the currently displayed entry of `data`.
    select = Range(0, len(data) - 1, 0)
    # Index shown by the previous update_plot() call; used to detect when
    # the displayed data must be swapped.
    last_select = deepcopy(select)
    iso_value = Range(iso_min, iso_max, iso_val, mode='logslider')
    opacity = Range(0, 1.0, 1.0)
    show_atoms = Bool(True)
    label = Str()
    available = List(Str)
    # NOTE(review): this rebinding replaces the List(Str) declaration on
    # the previous line with the raw `datalabels` value.
    available = datalabels

    prev_button = Button('Previous')
    next_button = Button('Next')

    scene = Instance(MlabSceneModel, ())

    plot_atoms = Instance(PipelineBase)
    plot0 = Instance(PipelineBase)

    # When the scene is activated, or when the parameters are changed, we
    # update the plot.
    @on_trait_change(
        'select,iso_value,show_atoms,opacity,label,scene.activated')
    def update_plot(self):
        # First call: create the plot for the selected entry.
        if self.plot0 is None:
            if not is_vectorfield:
                src = mlab.pipeline.scalar_field(X, Y, Z, data[self.select])
                # Symmetric +/- iso-surfaces around zero.
                self.plot0 = self.scene.mlab.pipeline.iso_surface(
                    src, contours=[-self.iso_value, self.iso_value],
                    opacity=self.opacity, colormap='blue-red',
                    vmin=-1e-8, vmax=1e-8)
            else:
                self.plot0 = self.scene.mlab.quiver3d(
                    X, Y, Z, *data[self.select])  # flow
            self.plot0.scene.background = (1, 1, 1)
        # Later calls: only swap the data source when the selection moved.
        elif self.select != self.last_select:
            if not is_vectorfield:
                self.plot0.mlab_source.set(scalars=data[self.select])
            else:
                self.plot0.mlab_source.set(
                    vectors=data[self.select].reshape((3, -1)).T)
        if not is_vectorfield:
            self.plot0.contour.contours = [-self.iso_value, self.iso_value]
        self.plot0.actor.property.opacity = self.opacity
        self.last_select = deepcopy(self.select)
        # Setting `label` re-enters this handler (it is in the watched
        # list); the second call finds nothing new to change.
        if datalabels is not None:
            self.label = datalabels[self.select]
        if geo_spec is not None:
            if self.plot_atoms is None:
                self.plot_atoms = self.scene.mlab.points3d(
                    geo_spec[:, 0], geo_spec[:, 1], geo_spec[:, 2],
                    scale_factor=0.75, resolution=20)
            self.plot_atoms.visible = self.show_atoms

    def _prev_button_fired(self):
        # Step back, stopping at the first entry.
        if self.select > 0:
            self.select -= 1

    def _next_button_fired(self):
        # Step forward, stopping at the last entry of `data`.
        if self.select < len(data) - 1:
            self.select += 1

    # The layout of the dialog created
    items = (Item('scene', editor=SceneEditor(scene_class=MayaviScene),
                  height=400, width=600, show_label=False), )
    items0 = ()
    # Navigation controls are only shown when there is more than one entry.
    if len(data) > 1:
        items0 += (Group(
            'select', HSplit(Item('prev_button', show_label=False),
                             Item('next_button', show_label=False))), )
    items0 += (Group('iso_value', 'opacity', 'show_atoms'), )

    if datalabels is not None:
        # With several labels, show the list of available data next to
        # the controls.
        if len(datalabels) > 1:
            items1 = (Item('available',
                           editor=ListStrEditor(title='Available Data',
                                                editable=False),
                           show_label=False, style='readonly', width=300), )
            items0 = HSplit(items0, items1)
        items += (
            Group(
                Item('label', label='Selected Data', style='readonly',
                     show_label=True),
                '_'),
            items0,
        )
    else:
        items += items0

    # Scene on top, everything else below it.
    view = View(VSplit(items[0], items[1:]), resizable=True)
class BoundaryMarkerEditor(Filter):
    """
    Edit the boundary marker of a Triangle surface mesh. To use: select the
    label to assign, hover your cursor over the cell you wish to edit, and
    press 'p'.
    """

    # The version of this class.  Used for persistence.
    __version__ = 0

    _current_grid = Instance(tvtk.UnstructuredGrid, allow_none=False)
    _input_grid = Instance(tvtk.UnstructuredGrid, args=(), allow_none=False)
    _extract_cells_filter = Instance(tvtk.ExtractCells, args=(),
                                     allow_none=False)
    _dataset_manager = Instance(DatasetManager, allow_none=False)
    # Indexed by cell id of the input grid; each value is the cell id in
    # the masked grid, or None when that cell is masked out.
    _cell_mappings = List

    label_to_apply = Range(0, 255)
    select_coplanar_cells = Bool
    epsilon = Range(0.0, 1.0, 0.0001)
    mask_labels = Bool
    labels_to_mask = List(label_to_apply)

    # Saving file
    output_file = File
    save = Button

    ######################################################################
    # The view.
    ######################################################################
    traits_view = \
        View(
            Group(
                Item(name='label_to_apply'),
                Item(name='select_coplanar_cells'),
                Item(name='epsilon', enabled_when='select_coplanar_cells',
                     label='Tolerance'),
                Item(name='mask_labels'),
                Group(
                    Item(name='labels_to_mask', style='custom',
                         editor=ListEditor(rows=3)),
                    show_labels=False,
                    show_border=True,
                    label='Labels to mask',
                    enabled_when='mask_labels==True'
                ),
                Group(
                    Item(name='output_file'),
                    Item(name='save', label='Save'),
                    show_labels=False,
                    show_border=True,
                    label='Save changes to file (give only a basename, without the file extension)'
                )
            ),
            height=500,
            width=600
        )

    ######################################################################
    # `Filter` interface.
    ######################################################################
    def update_pipeline(self):
        if len(self.inputs) == 0 or len(self.inputs[0].outputs) == 0:
            return

        picker = self.scene.picker
        # Call cell_picked() when a cell is clicked.  update_pipeline() runs
        # again whenever the upstream pipeline changes, so remove the
        # observer registered by an earlier call; otherwise observers stack
        # up and every pick is applied several times.
        previous = getattr(self, '_pick_observer', None)
        if previous is not None:
            old_cellpicker, old_tag = previous
            old_cellpicker.remove_observer(old_tag)
        tag = picker.cellpicker.add_observer("EndPickEvent",
                                             self.cell_picked)
        self._pick_observer = (picker.cellpicker, tag)
        picker.pick_type = 'cell_picker'
        picker.tolerance = 0.0
        picker.show_gui = False

        self._input_grid.deep_copy(self.inputs[0].outputs[0])

        self._current_grid = self._input_grid
        self._dataset_manager = DatasetManager(dataset=self._input_grid)
        self._set_outputs([self._current_grid])

        # Filter for masking.
        self._extract_cells_filter.set_input(self._input_grid)

    def update_data(self):
        self.data_changed = True

    ######################################################################
    # Non-public interface.
    ######################################################################
    def cell_picked(self, object, event):
        """Observer callback: relabel the picked cell (and optionally its
        coplanar neighbours), then refresh masking and the pipeline."""
        cell_id = self.scene.picker.cellpicker.cell_id
        self.modify_cell(cell_id, self.label_to_apply)

        if self.select_coplanar_cells:
            self.modify_neighbouring_cells(cell_id)

        if self.mask_labels:
            self.perform_mask()

        self._dataset_manager.activate(
            self._input_grid.cell_data.scalars.name, 'cell')
        self._dataset_manager.update()
        self.pipeline_changed = True

    def get_all_cell_neigbours(self, cell_id, cell):
        """Return the ids of all cells sharing an edge with `cell`."""
        neighbour_cell_ids = array([], dtype=int)

        for i in range(cell.number_of_edges):
            # Get points belonging to ith edge
            edge_point_ids = cell.get_edge(i).point_ids

            # Find neigbours which share the edge
            current_neighbour_cell_ids = tvtk.IdList()
            self._current_grid.get_cell_neighbors(
                cell_id, edge_point_ids, current_neighbour_cell_ids)
            neighbour_cell_ids = append(neighbour_cell_ids,
                                        array(current_neighbour_cell_ids))

        return neighbour_cell_ids.tolist()

    def modify_neighbouring_cells(self, cell_id):
        """Flood-fill `label_to_apply` over edge-connected cells whose
        normal is within `epsilon` of the normal of `cell_id`."""
        cell = self._current_grid.get_cell(cell_id)
        cell_normal = [0, 0, 0]
        cell.compute_normal(cell.points[0], cell.points[1], cell.points[2],
                            cell_normal)

        cells_pending = self.get_all_cell_neigbours(cell_id, cell)
        # A set gives O(1) membership tests (the list version was O(n)).
        cells_visited = set([cell_id])

        while cells_pending:
            current_cell_id = cells_pending.pop()
            if current_cell_id in cells_visited:
                continue
            cells_visited.add(current_cell_id)
            current_cell = self._current_grid.get_cell(current_cell_id)
            current_cell_normal = [0, 0, 0]
            current_cell.compute_normal(current_cell.points[0],
                                        current_cell.points[1],
                                        current_cell.points[2],
                                        current_cell_normal)

            if dot(cell_normal, current_cell_normal) > (1 - self.epsilon):
                self.modify_cell(current_cell_id, self.label_to_apply)
                cells_pending.extend(
                    self.get_all_cell_neigbours(current_cell_id,
                                                current_cell))

    def _mask_labels_changed(self):
        if self.mask_labels:
            self.perform_mask()
            self._current_grid = self._extract_cells_filter.output
        else:
            self._current_grid = self._input_grid

        self._set_outputs([self._current_grid])
        self.pipeline_changed = True

    def _labels_to_mask_changed(self):
        self.perform_mask()

    def _labels_to_mask_items_changed(self):
        self.perform_mask()

    def perform_mask(self):
        """Extract the cells whose label is not in `labels_to_mask` and
        rebuild `_cell_mappings`."""
        labels_array = self._input_grid.cell_data.get_array(
            self._input_grid.cell_data.scalars.name)
        masked_labels = set(self.labels_to_mask)
        unmasked_cells_list = tvtk.IdList()

        # Built as a real list (not map()) so it works on Python 3 too and
        # supports the .index() lookup in modify_cell().
        cell_mappings = []
        for cell_id, label in enumerate(labels_array):
            if label in masked_labels:
                cell_mappings.append(None)
            else:
                # insert_next_id returns the new id within the masked grid.
                cell_mappings.append(
                    unmasked_cells_list.insert_next_id(cell_id))
        self._cell_mappings = cell_mappings

        self._extract_cells_filter.set_cell_list(unmasked_cells_list)
        self._extract_cells_filter.update()
        self.pipeline_changed = True

    def modify_cell(self, cell_id, value):
        """Set the label of `cell_id` (an id in the current grid)."""
        if self.mask_labels:
            # Map the masked-grid id back to the input-grid id.
            cell_id = self._cell_mappings.index(cell_id)
        self._input_grid.cell_data.get_array(
            self._input_grid.cell_data.scalars.name)[cell_id] = value

    def _save_fired(self):
        from mayavi_amcg.triangle_writer import TriangleWriter
        if self.output_file:
            writer = TriangleWriter(self._input_grid, self.output_file)
            writer.write()
            print("#### Saved ####")
class Mass(HasTraits):
    """Unit system, gravity, air density and point masses of a mass file.

    ``lunit``/``munit``/``tunit`` are ``(scale, unit-name)`` pairs; the
    units of ``g`` and ``rho`` are given in the trailing comments.
    """
    lunit = Tuple((1.0, 'm'))  # meters
    munit = Tuple((1.0, 'kg'))  # kg
    tunit = Tuple((1.0, 's'))  # seconds
    g = Float(9.81)  # lunit / tunit**2
    rho = Float(1.225)  # munit / lunit**3
    objects = List(Instance(MassObject))

    def write_mass_file(self, file):
        """Write this description to the open text stream `file`.

        Scalar header values use ``%g`` so fractional values (e.g.
        ``g = 9.81``) are not truncated to integers as ``%d`` did.
        """
        file.write('Lunit = %g %s\n' % (self.lunit[0], self.lunit[1]))
        file.write('Munit = %g %s\n' % (self.munit[0], self.munit[1]))
        file.write('Tunit = %g %s\n' % (self.tunit[0], self.tunit[1]))
        file.write('g = %g\n' % self.g)
        file.write('rho = %g\n' % self.rho)
        file.write('\n')
        # The column header needs its own line; without the newline the
        # first mass row would end up inside the comment.
        file.write('# mass x y z Ixx Iyy Izz [ Ixy Ixz Iyz ]\n')
        for massobj in self.objects:
            file.write('%f\t' % (massobj.mass))
            file.write('%f %f %f\t' % (tuple(massobj.cg)))
            file.write('%f %f %f\t' % (tuple(massobj.inertia_moment)))
            file.write('%f %f %f\n' % (tuple(massobj.cross_inertia)))

    @classmethod
    def mass_from_file(cls, filename):
        """Read a mass file and return a new instance of `cls`.

        Leading ``name = value [unit]`` lines set the scalar traits; every
        line after the first non-matching one is a mass row
        ``mass x y z Ixx Iyy Izz [Ixy Ixz Iyz]``.
        """
        with open(filename) as mass_file:
            lines = mass_file.readlines()
        lines = filter_lines(lines)
        traits = {}
        # Index of the first mass row.  Defaults to "no rows" so a file with
        # only header lines (or no lines at all) parses cleanly.
        first_row = len(lines)
        for i, line in enumerate(lines):
            match = re.match(r'(?P<name>\S+?)\s*?=\s*?(?P<value>[-+]?(\d+(\.\d*)?|\.\d+)([eE][-+]?\d+)?)\s*?(?P<unit>\S+)?$', line)
            if match is None:
                first_row = i
                break
            name = match.group('name').lower()
            if name[1:] == 'unit':
                unit = match.group('unit')
                unit = unit if unit is not None else ''
                value = float(match.group('value'))
                traits[name] = value, unit
            else:
                value = float(match.group('value'))
                traits[name] = value
        multiplier = MassObject(mass=1.0, cg=numpy.ones(3),
                                inertia_moment=numpy.ones(3),
                                cross_inertia=numpy.ones(3))
        adder = MassObject(mass=0.0, cg=numpy.zeros(3),
                           inertia_moment=numpy.zeros(3),
                           cross_inertia=numpy.zeros(3))
        traits['objects'] = []
        for line in lines[first_row:]:
            vals = [float(num) for num in line.split()]
            cross_inertia = numpy.zeros(3)
            # Products of inertia are optional trailing columns.
            if len(vals) > 7:
                cross_inertia[:] = vals[7:]
            mass = vals[0]
            cg = vals[1:4]
            inertia_moment = vals[4:7]
            massobj = MassObject(
                mass=multiplier.mass * mass + adder.mass,
                cg=multiplier.cg * cg + adder.cg,
                inertia_moment=(multiplier.inertia_moment * inertia_moment
                                + adder.inertia_moment),
                cross_inertia=(multiplier.cross_inertia * cross_inertia
                               + adder.cross_inertia),
            )
            traits['objects'].append(massobj)
        # `cls` instead of a hard-coded `Mass` so subclasses get their type.
        return cls(**traits)
class Slot(Variable):
    """Trait holding an object of a given class or one that implements a
    given interface.

    Both Traits Interfaces and ``zope.interface.Interface`` subclasses are
    accepted as `klass`.  Values that are containers get their ``parent``
    (and, when present, ``_iotype``) set after assignment.
    """

    def __init__(self, klass=object, allow_none=True, factory=None,
                 args=None, kw=None, **metadata):
        try:
            is_zope_iface = issubclass(klass, zope.interface.Interface)
        except TypeError:
            # `klass` is not a class at all.
            is_zope_iface = False

        metadata.setdefault('copy', 'deep')

        self._allow_none = allow_none
        self.klass = klass

        self._is_container = bool(
            has_interface(klass, IContainer)
            or (isclass(klass) and IContainer.implementedBy(klass)))

        if is_zope_iface:
            # zope interfaces are checked by hand in validate().
            self._instance = None
            self.factory = factory
            self.args = args
            self.kw = kw
            initial = None
        else:
            self._instance = Instance(klass=klass, allow_none=allow_none,
                                      factory=factory, args=args, kw=kw,
                                      **metadata)
            initial = self._instance.default_value

        super(Slot, self).__init__(initial, **metadata)

    def validate(self, obj, name, value):
        if value is None:
            if self._allow_none:
                return value
            self.validate_failed(obj, name, value)

        if self._instance is None:
            # zope.interface: the value only has to provide the interface.
            if not self.klass.providedBy(value):
                self._iface_error(obj, name, self.klass.__name__)
            return value

        try:
            value = self._instance.validate(obj, name, value)
        except Exception:
            expected = self._instance.klass
            if issubclass(expected, Interface):
                self._iface_error(obj, name, expected.__name__)
            else:
                obj.raise_exception("%s must be an instance of class '%s'" %
                                    (name, expected.__name__), TypeError)
        return value

    def post_setattr(self, obj, name, value):
        # Containers need to know where they live in the hierarchy; doing it
        # here keeps side effects out of validate().
        if not self._is_container or value is None:
            return
        if value.parent is not obj:
            value.parent = obj
        # VariableTrees also carry the iotype of the slot.
        if hasattr(value, '_iotype'):
            value._iotype = self.iotype

    def _iface_error(self, obj, name, iface_name):
        obj.raise_exception("%s must provide interface '%s'" %
                            (name, iface_name), TypeError)
class Actor2D(Component):
    """Mayavi component managing a tvtk 2D actor, its mapper and property.

    ``setup_pipeline`` creates a ``tvtk.TextMapper`` (unless a mapper is
    already set), a ``tvtk.Actor2D`` and uses that actor's property.
    """

    # The version of this class.  Used for persistence.
    __version__ = 0

    # The mapper.
    mapper = Instance(tvtk.AbstractMapper, record=True)

    # The actor.
    actor = Instance(tvtk.Prop, record=True)

    # The actor's property.
    property = Instance(tvtk.Property2D, record=True)

    ########################################
    # View related traits.

    # The Actor's view group.
    _actor_group = Group(Item(name='visibility'),
                         Item(name='height'),
                         Item(name='width'),
                         show_border=True, label='Actor')

    # The View for this object.
    view = View(Group(Item(name='actor', style='custom',
                           editor=InstanceEditor(view=View(_actor_group))),
                      show_labels=False,
                      label='Actor'
                      ),
                Group(Item(name='mapper',
                           style='custom',
                           resizable=True),
                      show_labels=False,
                      label='Mapper'),
                Group(Item(name='property',
                           style='custom',
                           resizable=True),
                      show_labels=False,
                      label='Property'),
                resizable=True,
                )

    ######################################################################
    # `Component` interface
    ######################################################################
    def setup_pipeline(self):
        """Override this method so that it *creates* its tvtk
        pipeline.

        This method is invoked when the object is initialized via
        `__init__`.  Note that at the time this method is called, the
        tvtk data pipeline will *not* yet be setup.  So upstream data
        will not be available.  The idea is that you simply create the
        basic objects and setup those parts of the pipeline not
        dependent on upstream sources and filters.
        """
        if self.mapper is None:
            self.mapper = tvtk.TextMapper()
        self.actor = tvtk.Actor2D()
        self.property = self.actor.property

    def update_pipeline(self):
        """Override this method so that it *updates* the tvtk pipeline
        when data upstream is known to have changed.

        This method is invoked (automatically) when the input fires a
        `pipeline_changed` event.
        """
        if (len(self.inputs) == 0) or \
           (len(self.inputs[0].outputs) == 0):
            return
        self.mapper.input = self.inputs[0].outputs[0]
        self.render()

    def update_data(self):
        """Override this method to do what is necessary when upstream
        data changes.

        This method is invoked (automatically) when any of the inputs
        sends a `data_changed` event.
        """
        # Invoke render to update any changes.
        self.render()

    ######################################################################
    # Non-public interface.
    ######################################################################
    def _setup_handlers(self, old, new):
        """Move the re-render listener from `old` to `new` (either may be
        None)."""
        if old is not None:
            old.on_trait_change(self.render, remove=True)
        if new is not None:
            new.on_trait_change(self.render)

    def _mapper_changed(self, old, new):
        # Setup the handlers.
        self._setup_handlers(old, new)
        # Setup the inputs to the mapper.
        if (len(self.inputs) > 0) and (len(self.inputs[0].outputs) > 0):
            new.input = self.inputs[0].outputs[0]
        # Setup the actor's mapper.
        actor = self.actor
        if actor is not None:
            actor.mapper = new
        self.render()

    def _actor_changed(self, old, new):
        # Setup the handlers.
        self._setup_handlers(old, new)
        # Set the mapper.
        mapper = self.mapper
        if mapper is not None:
            new.mapper = mapper
        # Set the property.
        prop = self.property
        if prop is not None:
            new.property = prop
        # Setup the `actors` trait.
        self.actors = [new]

    def _property_changed(self, old, new):
        # Setup the handlers.
        self._setup_handlers(old, new)
        # Setup the actor.  The property can be assigned while no actor
        # exists yet; _actor_changed copies it onto the actor later.
        actor = self.actor
        if actor is not None and new is not actor.property:
            actor.property = new

    def _foreground_changed_for_scene(self, old, new):
        # Change the default color for the actor.
        self.property.color = new
        self.render()

    def _scene_changed(self, old, new):
        super(Actor2D, self)._scene_changed(old, new)
        self._foreground_changed_for_scene(None, new.foreground)
class TestObject(HasTraits):
    """Container for two Mayavi scene models."""

    # Each trait is given an args tuple of `()`, so every instance gets its
    # own default-constructed MlabSceneModel.
    scene1 = Instance(MlabSceneModel, ())
    scene2 = Instance(MlabSceneModel, ())
class FitsSource(Source):
    """A simple source that allows one to view a suitably shaped numpy
    array as ImageData.

    Only scalar data is handled: the array is set through `scalar_data`
    and stored as the point scalars of `image_data`.
    """

    # The scalar array data we manage.
    scalar_data = Trait(None, _check_scalar_array, rich_compare=False)

    # The name of our scalar array.
    scalar_name = Str('scalar')

    # The spacing of the points in the array.
    spacing = Array(dtype=float, shape=(3,), value=(1.0, 1.0, 1.0),
                    desc='the spacing between points in array')

    # The origin of the points in the array.
    origin = Array(dtype=float, shape=(3,), value=(0.0, 0.0, 0.0),
                   desc='the origin of the points in array')

    # Fire an event to update the spacing and origin - this reflushes
    # the pipeline.
    update_image_data = Button('Update spacing and origin')

    # The image data stored by this instance.
    image_data = Instance(tvtk.ImageData, allow_none=False)

    # Should we transpose the input data or not.  Transposing is
    # necessary to make the numpy array compatible with the way VTK
    # needs it.  However, transposing numpy arrays makes them
    # non-contiguous where the data is copied by VTK.  Thus, when the
    # user explicitly requests that transpose_input_array is false
    # then we assume that the array has already been suitably
    # formatted by the user.
    transpose_input_array = Bool(True, desc='if input array should be transposed (if on VTK will copy the input data)')

    # Information about what this object can produce.
    output_info = PipelineInfo(datasets=['image_data'])

    # Our view.
    view = View(Group(Item(name='transpose_input_array'),
                      Item(name='scalar_name'),
                      Item(name='spacing'),
                      Item(name='origin'),
                      Item(name='update_image_data', show_label=False),
                      show_labels=True)
                )

    ######################################################################
    # `object` interface.
    ######################################################################
    def __init__(self, **traits):
        # Set the scalar data at the end so we pop it here; its change
        # handler reads spacing/origin/transpose settings, which must be
        # applied first.
        sd = traits.pop('scalar_data', None)
        # Now set the other traits.
        super(FitsSource, self).__init__(**traits)
        # And finally set the scalar data.
        if sd is not None:
            self.scalar_data = sd

        # Setup the mayavi pipeline by sticking the image data into
        # our outputs.
        self.outputs = [self.image_data]

    def __get_pure_state__(self):
        # The tvtk image data is not persisted.
        d = super(FitsSource, self).__get_pure_state__()
        d.pop('image_data', None)
        return d

    ######################################################################
    # ArraySource interface.
    ######################################################################
    def update(self):
        """Call this function when you change the array data
        in-place."""
        d = self.image_data
        d.modified()
        pd = d.point_data
        if self.scalar_data is not None:
            pd.scalars.modified()
        self.data_changed = True

    ######################################################################
    # Non-public interface.
    ######################################################################
    def _image_data_default(self):
        s = tuple(self.spacing)
        o = tuple(self.origin)
        return tvtk.ImageData(spacing=s, origin=o)

    def _image_data_changed(self, value):
        self.outputs = [value]

    def _update_image_data_fired(self):
        # Rebuild the image data with the current spacing/origin and
        # re-attach the current scalars to it.
        sp = tuple(self.spacing)
        o = tuple(self.origin)
        self.image_data = tvtk.ImageData(spacing=sp, origin=o)
        sd = self.scalar_data
        if sd is not None:
            self._scalar_data_changed(sd)

    def _scalar_data_changed(self, data):
        img_data = self.image_data
        if data is None:
            img_data.point_data.scalars = None
            self.data_changed = True
            return
        # A 2D array is treated as a single z-slice.
        dims = list(data.shape)
        if len(dims) == 2:
            dims.append(1)

        img_data.origin = tuple(self.origin)
        img_data.dimensions = tuple(dims)
        img_data.extent = 0, dims[0]-1, 0, dims[1]-1, 0, dims[2]-1
        img_data.update_extent = 0, dims[0]-1, 0, dims[1]-1, 0, dims[2]-1
        if self.transpose_input_array:
            img_data.point_data.scalars = numpy.ravel(numpy.transpose(data))
        else:
            img_data.point_data.scalars = numpy.ravel(data)
        img_data.point_data.scalars.name = self.scalar_name
        # This is very important and if not done can lead to a segfault!
        typecode = data.dtype
        img_data.scalar_type = array_handler.get_vtk_array_type(typecode)
        img_data.update() # This sets up the extents correctly.
        img_data.update_traits()

        # Now flush the mayavi pipeline.
        self.data_changed = True

    def _scalar_name_changed(self, value):
        if self.scalar_data is not None:
            self.image_data.point_data.scalars.name = value
            self.data_changed = True

    def _transpose_input_array_changed(self, value):
        # Re-run the scalar setup so the new layout is applied.
        if self.scalar_data is not None:
            self._scalar_data_changed(self.scalar_data)
class DatasetManager(HasTraits):
    """Exposes the arrays of a tvtk dataset as numpy arrays and lets the
    active scalars/vectors/tensors be switched through an AssignAttribute
    filter, whose output is available as `output`."""

    # The TVTK dataset we manage.
    dataset = Instance(tvtk.DataSet)

    # Our output, this is the dataset modified by us with different
    # active arrays.
    output = Property(Instance(tvtk.DataSet))

    # The point scalars for the dataset.  You may manipulate the arrays
    # in-place.  However adding new keys in this dict will not set the
    # data in the `dataset` for that you must explicitly call
    # `add_array`.
    point_scalars = Dict(Str, Array)
    # Point vectors.
    point_vectors = Dict(Str, Array)
    # Point tensors.
    point_tensors = Dict(Str, Array)

    # The cell scalars for the dataset.
    cell_scalars = Dict(Str, Array)
    cell_vectors = Dict(Str, Array)
    cell_tensors = Dict(Str, Array)

    # This filter allows us to change the attributes of the data
    # object and will ensure that the pipeline is properly taken care
    # of.  Directly setting the array in the VTK object will not do
    # this.
    _assign_attribute = Instance(tvtk.AssignAttribute, args=(),
                                 allow_none=False)

    ######################################################################
    # Public interface.
    ######################################################################
    def add_array(self, array, name, category='point'):
        """
        Add an array to the dataset to specified category ('point' or
        'cell').
        """
        assert len(array.shape) <= 2, "Only 2D arrays can be added."
        data = getattr(self.dataset, '%s_data' % category)
        if len(array.shape) == 2:
            assert array.shape[1] in [1, 3, 4, 9], \
                "Only Nxm arrays where (m in [1,3,4,9]) are supported"
            # Column count decides which attribute dict the array goes to.
            kind = {1: 'scalars', 3: 'vectors', 4: 'scalars',
                    9: 'tensors'}[array.shape[1]]
        else:
            kind = 'scalars'
        va = tvtk.to_tvtk(array2vtk(array))
        va.name = name
        data.add_array(va)
        store = getattr(self, '%s_%s' % (category, kind))
        store[name] = array

    def remove_array(self, name, category='point'):
        """Remove an array by its name and optional category (point and
        cell).  Returns the removed array.
        """
        kind = self._find_array(name, category)
        data = getattr(self.dataset, '%s_data' % category)
        data.remove_array(name)
        store = getattr(self, '%s_%s' % (category, kind))
        return store.pop(name)

    def rename_array(self, name1, name2, category='point'):
        """Rename a particular array from `name1` to `name2`.
        """
        kind = self._find_array(name1, category)
        data = getattr(self.dataset, '%s_data' % category)
        arr = data.get_array(name1)
        arr.name = name2
        store = getattr(self, '%s_%s' % (category, kind))
        store[name2] = store.pop(name1)

    def activate(self, name, category='point'):
        """Make the specified array the active one.
        """
        kind = self._find_array(name, category)
        self._activate_data_array(kind, category, name)

    def update(self):
        """Update the dataset when the arrays are changed.
        """
        self.dataset.modified()
        self._assign_attribute.update()

    ######################################################################
    # Non-public interface.
    ######################################################################
    def _dataset_changed(self, value):
        self._setup_data()
        self._assign_attribute.input = value

    def _get_output(self):
        return self._assign_attribute.output

    def _setup_data(self):
        """Updates the arrays from what is available in the input data.
        """
        pnt_attr, cell_attr = get_all_attributes(self.dataset)

        self._setup_data_arrays(cell_attr, 'cell')
        self._setup_data_arrays(pnt_attr, 'point')

    def _setup_data_arrays(self, attributes, d_type):
        """Given the dict of the attributes from the
        `get_all_attributes` function and the data type (point/cell)
        data this will setup the object and the data.
        """
        data = getattr(self.dataset, '%s_data' % d_type)
        for attr in ('scalars', 'vectors', 'tensors'):
            # Get the arrays from VTK, create numpy arrays and setup our
            # traits.
            arrays = {}
            for name in attributes[attr]:
                va = data.get_array(name)
                npa = va.to_array()
                self._link_numpy_to_vtk(va, npa)
                arrays[name] = npa

            setattr(self, '%s_%s' % (d_type, attr), arrays)

    def _link_numpy_to_vtk(self, va, npa):
        """Make in-place edits of `npa` visible in the VTK array `va`.

        The first element of `npa` is perturbed; if `va` does not see the
        change, `va.from_array(npa)` is called.  The element is restored
        afterwards.  Empty arrays are skipped: there is no element to probe
        and indexing ``[0]`` would raise IndexError.
        """
        if npa.size == 0:
            return
        if len(npa.shape) > 1:
            old = npa[0, 0]
            npa[0][0] = old - 1
            if abs(va[0][0] - npa[0, 0]) > 1e-8:
                va.from_array(npa)
            npa[0][0] = old
        else:
            old = npa[0]
            npa[0] = old - 1
            if abs(va[0] - npa[0]) > 1e-8:
                va.from_array(npa)
            npa[0] = old

    def _activate_data_array(self, data_type, category, name):
        """Activate (or deactivate) a particular array.

        Given the nature of the data (scalars, vectors etc.) and the
        type of data (cell or points) it activates the array given by
        its name.

        Parameters:
        -----------

        data_type: one of 'scalars', 'vectors', 'tensors'
        category: one of 'cell', 'point'.
        name: string of array name to activate.
        """
        data = getattr(self.dataset, category + '_data')
        method = getattr(data, 'set_active_%s' % data_type)
        if len(name) == 0:
            # If the value is empty then we deactivate that attribute.
            method(None)
        else:
            aa = self._assign_attribute
            method(name)
            aa.assign(name, data_type.upper(), category.upper() + '_DATA')
            aa.update()

    def _find_array(self, name, category='point'):
        """Return information on which kind of attribute contains the
        specified named array in a particular category."""
        for kind in ('scalars', 'vectors', 'tensors'):
            if name in getattr(self, '%s_%s' % (category, kind)):
                return kind
        raise KeyError('No %s array named %s available in dataset'
                       % (category, name))
class VariableMeshPannerView(HasTraits):
    """Chaco view of a variable-mesh image plot plus a linked colorbar.

    The colorbar and an overlay container holding the plot are placed in
    an ``HPlotContainer`` (`full_container`).  A `field` trait delegating
    to `vm_plot` is added to every instance at construction.
    """

    plot = Instance(Plot)
    spawn_zoom = Button
    vm_plot = Instance(VMImagePlot)
    use_tools = Bool(True)
    full_container = Instance(HPlotContainer)
    container = Instance(OverlayPlotContainer)

    traits_view = View(
        Group(
            Item('full_container',
                 editor=ComponentEditor(size=(512, 512)),
                 show_label=False),
            Item('field', show_label=False),
            orientation="vertical"),
        width=800, height=800,
        resizable=True, title="Pan and Scan",
    )

    def _vm_plot_default(self):
        # NOTE(review): `panner` is not declared as a trait here; it is
        # presumably supplied as a constructor keyword -- confirm.
        return VMImagePlot(panner=self.panner)

    def __init__(self, **kwargs):
        super(VariableMeshPannerView, self).__init__(**kwargs)
        # Forward our `field` trait to vm_plot.field.
        self.add_trait("field", DelegatesTo("vm_plot"))

        main_plot = self.vm_plot.plot
        image = self.vm_plot.img_plot
        if self.use_tools:
            self._install_image_tools(main_plot, image)
        self._install_colorbar(image)
        self._assemble_containers()

    def _install_image_tools(self, main_plot, image):
        """Attach pan, box-zoom and value-inspector tools to the image."""
        main_plot.tools.append(PanTool(image))
        box_zoom = ZoomTool(component=image, tool_mode="box",
                            always_on=False)
        main_plot.overlays.append(box_zoom)

        inspector = ImageInspectorTool(image)
        image.tools.append(inspector)
        image.overlays.append(
            ImageInspectorOverlay(component=image,
                                  image_inspector=inspector,
                                  bgcolor="white",
                                  border_visible=True))

    def _install_colorbar(self, image):
        """Create `self.colorbar` for `image` with pan, range-zoom and a
        range selection that is forwarded to the image plot."""
        value_range = DataRange1D(self.vm_plot.fid)
        index_mapper = LinearMapper(range=value_range)
        self.colorbar = ColorBar(index_mapper=index_mapper,
                                 plot=image,
                                 padding_right=40,
                                 resizable='v',
                                 width=30)
        bar = self.colorbar

        bar.tools.append(
            PanTool(bar, constrain_direction="y", constrain=True))
        bar.overlays.append(
            ZoomTool(bar, axis="index", tool_mode="range",
                     always_on=True, drag_button="right"))

        # Range selection on the colorbar...
        selection = RangeSelection(component=bar)
        bar.tools.append(selection)
        bar.overlays.append(
            RangeSelectionOverlay(component=bar,
                                  border_color="white",
                                  alpha=0.8,
                                  fill_color="lightgray"))
        # ...which is also reported to the image plot.
        selection.listeners.append(image)

    def _assemble_containers(self):
        """Lay out colorbar and plot inside the Chaco containers."""
        self.full_container = HPlotContainer(padding=30)
        self.container = OverlayPlotContainer(padding=0)
        self.full_container.add(self.colorbar)
        self.full_container.add(self.container)
        self.container.add(self.vm_plot.plot)
class EXDesignReader(HasTraits):
    '''Read the data from the directory

    The design is described in semicolon-separated
    csv file providing the information about
    design parameters.

    Each file has the name n.txt
    '''

    #--------------------------------------------------------------------
    # Specification of the design - factor list, relative paths, etc
    #--------------------------------------------------------------------
    open_exdesign = Button()

    def _open_exdesign_fired(self):
        file_name = open_file(filter=['*.eds'],
                              extensions=[FileInfo(), TextInfo()])
        # An empty string means the dialog was cancelled.
        if file_name != '':
            self.exdesign_spec_file = file_name

    exdesign_spec_file = File

    def _exdesign_spec_file_changed(self):
        print('changed file')
        # `open` instead of the Python-2-only `file`; the `with` block
        # closes the handle, which was previously leaked.
        with open(self.exdesign_spec_file) as spec_file:
            spec_source = spec_file.read()
        # SECURITY: eval() executes the specification file as Python code.
        # Only load .eds files from trusted sources.
        self.exdesign_spec = eval('ExDesignSpec( %s )' % spec_source)

    exdesign_spec = Instance(ExDesignSpec)

    def _exdesign_spec_default(self):
        return ExDesignSpec()

    @on_trait_change('exdesign_spec')
    def _reset_design_file(self):
        # The design file is given relative to the spec file's directory.
        spec_dir = os.path.dirname(self.exdesign_spec_file)
        exdesign_file = self.exdesign_spec.design_file
        self.design_file = os.path.join(spec_dir, exdesign_file)

    #--------------------------------------------------------------------
    # file containing the association between the factor combinations
    # and data files having the data
    #--------------------------------------------------------------------
    design_file = File

    def _design_file_changed(self):
        self.exdesign = self._read_exdesign()

    exdesign_table_columns = Property(List, depends_on='exdesign_spec+')

    @cached_property
    def _get_exdesign_table_columns(self):
        return [ObjectColumn(name=ps[2],
                             editable=False,
                             width=0.15) for ps in self.exdesign_spec.factors]

    exdesign = List(Any)

    def _exdesign_default(self):
        return self._read_exdesign()

    def _read_exdesign(self):
        ''' Read the experiment design.

        Returns one ExRun per row of the semicolon-separated design file,
        or an empty list when the file does not exist.
        '''
        if not exists(self.design_file):
            return []
        data_dir = os.path.join(os.path.dirname(self.design_file),
                                self.exdesign_spec.data_dir)
        # The rows are fully consumed inside the `with`, so the file is
        # closed once the runs are built (it used to stay open).
        with open(self.design_file, 'r') as design:
            reader = csv.reader(design, delimiter=';')
            return [ExRun(self, row, data_dir=data_dir) for row in reader]

    selected_exrun = Instance(ExRun)

    def _selected_exrun_default(self):
        if len(self.exdesign) > 0:
            return self.exdesign[0]
        else:
            return None

    last_exrun = Instance(ExRun)

    selected_exruns = List(ExRun)

    #------------------------------------------------------------------
    # Array plotting
    #-------------------------------------------------------------------
    # List of arrays to be plotted
    data = Instance(AbstractPlotData)

    def _data_default(self):
        return ArrayPlotData(x=array([]), y=array([]))

    @on_trait_change('selected_exruns')
    def _rest_last_exrun(self):
        if len(self.selected_exruns) > 0:
            self.last_exrun = self.selected_exruns[-1]

    @on_trait_change('selected_exruns')
    def _reset_data(self):
        '''Re-plot measured (brown) and fitted (blue) curves of the
        selected runs.
        '''
        _, xlabels, ylabels, ylabels_fitted = self._generate_data_labels()
        for name in list(self.plot.plots.keys()):
            self.plot.delplot(name)

        # Register data sources that are not known to the plot yet.
        for exrun, xlabel, ylabel, ylabel_fitted in zip(
                self.selected_exruns, xlabels, ylabels, ylabels_fitted):
            if xlabel not in self.plot.datasources:
                self.plot.datasources[xlabel] = ArrayDataSource(
                    exrun.xdata, sort_order='none')
            if ylabel not in self.plot.datasources:
                self.plot.datasources[ylabel] = ArrayDataSource(
                    exrun.ydata, sort_order='none')
            if ylabel_fitted not in self.plot.datasources:
                self.plot.datasources[ylabel_fitted] = ArrayDataSource(
                    exrun.polyfit, sort_order='none')

        for xlabel, ylabel, ylabel_fitted in zip(xlabels, ylabels,
                                                 ylabels_fitted):
            self.plot.plot((xlabel, ylabel), color='brown')
            self.plot.plot((xlabel, ylabel_fitted), color='blue')

    def _generate_data_labels(self):
        ''' Generate the labels consisting of the axis and run-number.
        '''
        return ([e.std_num for e in self.selected_exruns],
                ['x-%d' % e.std_num for e in self.selected_exruns],
                ['y-%d' % e.std_num for e in self.selected_exruns],
                ['y-%d-fitted' % e.std_num for e in self.selected_exruns])

    plot = Instance(Plot)

    def _plot_default(self):
        p = Plot()
        p.tools.append(PanTool(p))
        p.overlays.append(ZoomTool(p))
        return p

    view_traits = View(HSplit(VGroup(Item('open_exdesign',
                                          style='simple'),
                                     Item('exdesign',
                                          editor=exrun_table_editor,
                                          show_label=False, style='custom')
                                     ),
                              VGroup(Item('last_exrun@',
                                          show_label=False),
                                     Item('plot',
                                          editor=ComponentEditor(),
                                          show_label=False,
                                          resizable=True
                                          ),
                                     ),
                              ),
                       # handler = EXDesignReaderHandler(),
                       resizable=True,
                       buttons=[OKButton, CancelButton],
                       height=1.,
                       width=1.)
class ActionItem(ActionManagerItem):
    """ An action manager item that represents an actual action.

    The item mirrors its 'enabled' and 'visible' state onto the wrapped
    action and creates toolkit wrappers (menu items, tool bar tools, palette
    tools) for it on demand.
    """

    #### 'ActionManagerItem' interface ########################################

    # The item's unique identifier ('unique' in this case means unique within
    # its group).
    id = Property(Str)

    #### 'ActionItem' interface ###############################################

    # The action!
    action = Instance(Action)

    # The toolkit specific control created for this item.
    control = Any

    # The toolkit specific Id of the control created for this item.
    #
    # We have to keep the Id as well as the control because wx tool bar tools
    # are created as 'wxObjectPtr's which do not have Ids, and the Id is
    # required to manipulate the state of a tool via the tool bar 8^(
    # FIXME v3: Why is this part of the public interface?
    control_id = Any

    #### Private interface ####################################################

    # All of the internal instances that wrap this item.
    _wrappers = List(Any)

    ###########################################################################
    # 'ActionManagerItem' interface.
    ###########################################################################

    #### Trait properties #####################################################

    def _get_id(self):
        """ Returns the item's Id, which is the Id of the wrapped action. """

        return self.action.id

    #### Trait change handlers ################################################

    def _enabled_changed(self, trait_name, old, new):
        """ Static trait change handler: mirrors 'enabled' onto the action. """

        self.action.enabled = new

    def _visible_changed(self, trait_name, old, new):
        """ Static trait change handler: mirrors 'visible' onto the action. """

        # Propagate the new value; previously this always assigned True, so
        # hiding the item could never hide the underlying action.
        self.action.visible = new

    ###########################################################################
    # 'ActionItem' interface.
    ###########################################################################

    def add_to_menu(self, parent, menu, controller):
        """ Adds the item to a menu.

        parent: the toolkit parent passed through to the menu item wrapper.
        menu: the menu to add to.
        controller: optional; when not None it may veto the addition via
            'can_add_to_menu(action)'.
        """

        if (controller is None) or controller.can_add_to_menu(self.action):
            wrapper = _MenuItem(parent, menu, self, controller)

            # fixme: Martin, who uses this information?
            # Only recorded when there is no controller.
            if controller is None:
                self.control = wrapper.control
                self.control_id = wrapper.control_id

            self._wrappers.append(wrapper)

    def add_to_toolbar(self, parent, tool_bar, image_cache, controller,
                       show_labels=True):
        """ Adds the item to a tool bar.

        parent: the toolkit parent passed through to the tool wrapper.
        tool_bar: the tool bar to add to.
        image_cache: passed through to the tool wrapper.
        controller: optional; when not None it may veto the addition via
            'can_add_to_toolbar(action)'.
        show_labels: passed through to the tool wrapper.
        """

        if (controller is None) or controller.can_add_to_toolbar(self.action):
            wrapper = _Tool(
                parent, tool_bar, image_cache, self, controller, show_labels
            )

            # fixme: Martin, who uses this information?
            # Only recorded when there is no controller.
            if controller is None:
                self.control = wrapper.control
                self.control_id = wrapper.control_id

            self._wrappers.append(wrapper)

    def add_to_palette(self, tool_palette, image_cache, show_labels=True):
        """ Adds the item to a tool palette.

        The palette wrapper is always created (there is no controller veto).
        """

        wrapper = _PaletteTool(tool_palette, image_cache, self, show_labels)

        self._wrappers.append(wrapper)

    def destroy(self):
        """ Called when the action is no longer required.

        By default this method calls 'destroy' on the action itself.
        """

        self.action.destroy()
class OpsViewerTask(BaseViewerTask):
    """ Viewer task exposing array operations (FFT, local extrema search,
    complex-part extraction, discrete Gaussian filtering) as buttons.
    """

    # Maps the user-visible boundary name to the integer code passed to
    # local_maxima().
    boundary_map = dict(constant=0, finite=1, periodic=2, reflective=3)
    boundary = Enum(sorted(boundary_map))
    threshold = Float(0)
    max_nof_points = Int(100)

    runner_thread = Instance(RunnerThread)

    compute_fft_button = Button('Compute FFT')
    compute_ifft_button = Button('Compute IFFT')
    clear_fft_worker_button = Button('Clear FFT worker cache')
    find_local_maxima_button = Button('Find local maxima')
    find_local_minima_button = Button('Find local minima')
    clear_points_button = Button('Clear points')
    take_real_button = Button('Take real')
    take_imag_button = Button('Take imag')
    take_abs_button = Button('Take abs')
    discrete_gauss_blur_button = Button('Discrete Gauss blur')
    discrete_gauss_laplace_button = Button('Discrete Gauss Laplace')

    # The three tuples below are kept mutually consistent by the
    # _discrete_gauss_*_changed handlers (Z, Y, X order).
    discrete_gauss_scales = Tuple(0.0, 0.0, 0.0)
    discrete_gauss_widths = Tuple(0.0, 0.0, 0.0)
    discrete_gauss_sizes = Tuple(0.0, 0.0, 0.0)

    stop_button = Button('Stop')

    traits_view = View(
        VGroup(
            HSplit(
                Item('compute_fft_button', show_label=False),
                Item('compute_ifft_button', show_label=False),
                Item('clear_fft_worker_button', show_label=False),
            ),
            '_',
            HSplit(Item('boundary'),
                   Item('threshold'),
                   Item('max_nof_points'),
                   visible_when='is_real'),
            HSplit(Item('find_local_maxima_button', show_label=False),
                   Item('find_local_minima_button', show_label=False),
                   Item('clear_points_button', show_label=False),
                   Item('stop_button', show_label=False),
                   visible_when='is_real'),
            '_',
            HSplit(Item('take_real_button', show_label=False),
                   Item('take_imag_button', show_label=False),
                   Item('take_abs_button', show_label=False),
                   visible_when='is_complex'),
            '_',
            Group(
                HSplit(Item('discrete_gauss_blur_button', show_label=False),
                       Item('discrete_gauss_laplace_button',
                            show_label=False)),
                Item(
                    'discrete_gauss_sizes',
                    label='DG FWHM [um]',
                    editor=TupleEditor(cols=3, labels=['Z', 'Y', 'X']),
                ),
                Item(
                    'discrete_gauss_widths',
                    label='DG FWHM [px]',
                    editor=TupleEditor(cols=3, labels=['Z', 'Y', 'X']),
                ),
                Item(
                    'discrete_gauss_scales',
                    label='DG scales',
                    editor=TupleEditor(cols=3, labels=['Z', 'Y', 'X']),
                ),
            )))

    @property
    def fft_worker(self):
        """ FFT worker for the viewer's current data shape.

        The worker is cached in the module-level pair
        ``_fft_worker_cache = [shape, worker]`` and rebuilt whenever the data
        shape differs from the cached shape.
        """
        if _fft_worker_cache[0] != self.viewer.data.shape:
            timeit = TimeIt(self.viewer, 'creating FFT worker')
            _fft_worker_cache[0] = self.viewer.data.shape
            _fft_worker_cache[1] = FFTTasks(self.viewer.data.shape,
                                            options=Options(fftw_threads=4))
            timeit.stop()
        return _fft_worker_cache[1]

    @fft_worker.deleter
    def fft_worker(self):
        """ Drops the cached FFT worker and its shape key. """
        # TimeIt takes the viewer as its first argument everywhere else; the
        # previous call passed only the message and never stopped the timer.
        timeit = TimeIt(self.viewer, 'removing FFT worker')
        _fft_worker_cache[0] = None
        _fft_worker_cache[1] = None
        timeit.stop()

    def _clear_fft_worker_button_fired(self):
        """ Clears and forgets the cached FFT worker, if any. """
        # Read the cache directly: going through the property would build a
        # brand-new worker (when the shape changed) just to throw it away.
        fft_worker = _fft_worker_cache[1]
        if fft_worker is not None:
            fft_worker.clear()
        del self.fft_worker

    def _discrete_gauss_scales_changed(self, old, new):
        """ Scales (variances) -> FWHM widths in pixels. """
        if old == new:
            return
        # 2.3548 ~= 2*sqrt(2*ln 2): Gaussian FWHM per unit standard deviation;
        # the scale is treated as a variance, hence the square root.
        self.discrete_gauss_widths = tuple(
            [2.3548 * t**0.5 for t in self.discrete_gauss_scales])

    def _discrete_gauss_widths_changed(self, old, new):
        """ FWHM widths in pixels -> scales and physical sizes. """
        if old == new:
            return
        self.discrete_gauss_scales = tuple([
            (w / 2.3548)**2 for w in self.discrete_gauss_widths
        ])
        self.discrete_gauss_sizes = tuple([
            w * s
            for w, s in zip(self.discrete_gauss_widths, self.viewer.voxel_sizes)
        ])

    def _discrete_gauss_sizes_changed(self, old, new):
        """ Physical FWHM sizes -> widths in pixels (via voxel sizes). """
        if old == new:
            return
        self.discrete_gauss_widths = tuple([
            w / s
            for w, s in zip(self.discrete_gauss_sizes, self.viewer.voxel_sizes)
        ])

    def _name_default(self):
        return 'Operations'

    def startup(self):
        """ Nothing to do: the FFT worker is created lazily by 'fft_worker'. """
        pass

    def _clear_points_button_fired(self):
        self.viewer.set_point_data([])
        self.viewer.reset()
        self.viewer.status = 'cleared point data'

    def _find_local_maxima_button_fired(self):
        """ Finds local maxima above 'threshold' and shows the largest ones.

        At most 'max_nof_points' points are passed to the viewer, after
        sorting in descending order. The search is aborted once it has
        collected more than a million points.
        """
        data = self.viewer.data
        if not isinstance(data, numpy.ndarray):
            data = data[:]  # tiffarray
        threshold = self.threshold
        boundary = self.boundary_map[self.boundary]
        timeit = TimeIt(self.viewer, 'computing local maxima')

        def cb(done, result, timeit=timeit):
            # Progress callback; raising here aborts local_maxima().
            timeit.update('%.2f%% done, %s points sofar' %
                          (done * 100.0, len(result)))
            if len(result) > 1000000:
                raise RuntimeError(
                    'too many points (more than a million), try blurring')

        try:
            points = local_maxima(data, threshold, boundary, cb)
        except Exception as msg:
            timeit.stop('failed with exception: %s' % (msg))
            raise
        timeit.stop()
        points.sort(reverse=True)
        self.viewer.set_point_data(points[:self.max_nof_points])
        self.viewer.reset()
class FileDataSource(Source):
    """ A source backed by one file or by a numbered time series of files.

    Setting 'base_file_name' populates 'file_list' (via get_file_list) and
    'timestep' selects which entry of that list becomes 'file_path'.
    """

    # The version of this class.  Used for persistence.
    __version__ = 0

    # The list of file names for the timeseries.
    file_list = List(Str, desc='a list of files belonging to a time series')

    # The current time step (starts with 0).  This trait is a dummy
    # and is dynamically changed when the `file_list` trait changes.
    # This is done so the timestep bounds are linked to the number of
    # the files in the file list.
    timestep = Range(value=0,
                     low='_min_timestep',
                     high='_max_timestep',
                     enter_set=True,
                     auto_set=False,
                     desc='the current time step')

    base_file_name = Str('',
                         desc="the base name of the file",
                         enter_set=True,
                         auto_set=False,
                         editor=FileEditor())

    # A timestep view group that may be included by subclasses.
    time_step_group = Group(
        Item(name='file_path', style='readonly'),
        Item(name='timestep', defined_when='len(object.file_list) > 1'))

    ##################################################
    # Private traits.
    ##################################################

    # The current file name.  This is not meant to be touched by the
    # user.
    file_path = Instance(FilePath, (), desc='the current file name')

    _min_timestep = Int(0)
    _max_timestep = Int(0)

    ######################################################################
    # `object` interface
    ######################################################################
    def __get_pure_state__(self):
        """ Returns the persistable state without 'file_list'/'timestep'. """
        d = super(FileDataSource, self).__get_pure_state__()
        # These are obtained dynamically, so don't pickle them.
        for x in ['file_list', 'timestep']:
            d.pop(x, None)
        return d

    def __set_pure_state__(self, state):
        """ Restores state, re-deriving the file list from the saved path.

        Raises IOError if the saved file no longer exists.
        """
        # Use the saved path to initialize the file_list and timestep.
        fname = state.file_path.abs_pth
        if not isfile(fname):
            msg = 'Could not find file at %s\n' % fname
            msg += 'Please move the file there and try again.'
            raise IOError(msg)

        self.initialize(fname)
        # Now set the remaining state without touching the children.
        set_state(self, state, ignore=['children', 'file_path'])
        # Setup the children.
        handle_children_state(self.children, state.children)
        # Setup the children's state.
        set_state(self, state, first=['children'], ignore=['*'])

    ######################################################################
    # `FileDataSource` interface
    ######################################################################
    def initialize(self, base_file_name):
        """Given a single filename which may or may not be part of a
        time series, this initializes the list of files.  This method
        need not be called to initialize the data.
        """
        self.base_file_name = base_file_name

    ######################################################################
    # Non-public interface
    ######################################################################
    def _file_list_changed(self, value):
        """ Re-clamps 'timestep' and its upper bound to the new file list. """
        # Change the range of the timestep suitably to reflect new list.
        n_files = len(self.file_list)
        # Valid indices are 0 .. n_files - 1; clamping to n_files (as before)
        # let file_list[timestep] go out of range when the list shrank.
        timestep = max(min(self.timestep, n_files - 1), 0)
        self._max_timestep = max(n_files - 1, 0)
        if self.timestep == timestep:
            # No change notification would fire, so refresh file_path here.
            self._timestep_changed(timestep)
        else:
            self.timestep = timestep

    def _file_list_items_changed(self, list_event):
        self._file_list_changed(self.file_list)

    def _timestep_changed(self, value):
        """ Points 'file_path' at the file for the selected time step. """
        file_list = self.file_list
        if len(file_list) > 0:
            self.file_path = FilePath(file_list[value])
        else:
            self.file_path = FilePath('')

    def _base_file_name_changed(self, value):
        """ Rebuilds the file list and selects the given file in it. """
        self.file_list = get_file_list(value)
        if len(self.file_list) == 0:
            self.file_list = [value]
        try:
            self.timestep = self.file_list.index(value)
        except ValueError:
            self.timestep = 0
class TreeItem(HasTraits):
    """ A generic base-class for items in a tree data structure. """

    #### 'TreeItem' interface #################################################

    # Does this item allow children?
    allows_children = Bool(True)

    # The item's children.
    children = List(Instance('TreeItem'))

    # Arbitrary data associated with the item.
    data = Any

    # Does the item have any children?
    has_children = Property(Bool)

    # The item's parent.
    parent = Instance('TreeItem')

    ###########################################################################
    # 'object' interface.
    ###########################################################################

    def __str__(self):
        """ Returns str(data), or '' when there is no data. """

        if self.data is None:
            s = ''

        else:
            s = str(self.data)

        return s

    ###########################################################################
    # 'TreeItem' interface.
    ###########################################################################

    #### Properties ###########################################################

    # has_children
    def _get_has_children(self):
        """ True iff the item has children. """

        return len(self.children) != 0

    #### Methods ##############################################################

    def append(self, child):
        """ Appends a child to this item.

        This removes the child from its current parent (if it has one).
        """

        return self.insert(len(self.children), child)

    def insert(self, index, child):
        """ Inserts a child into this item at the specified index.

        This removes the child from its current parent (if it has one).
        Returns the child.
        """

        if child.parent is not None:
            child.parent.remove(child)

        child.parent = self
        self.children.insert(index, child)

        return child

    def remove(self, child):
        """ Removes a child from this item and returns it. """

        child.parent = None
        self.children.remove(child)

        return child

    def insert_before(self, before, child):
        """ Inserts a child into this item before the specified item.

        This removes the child from its current parent (if it has one).
        Returns (index of 'before' before the insertion, child).
        Raises ValueError if 'before' is not a child of this item.
        """

        # Validate the anchor before mutating anything.
        index = self.children.index(before)
        if child is not before:
            # Detach first: if child is already a sibling positioned ahead of
            # 'before', removing it shifts 'before' one slot to the left.
            if child.parent is not None:
                child.parent.remove(child)
            index = self.children.index(before)
            self.insert(index, child)

        return (index, child)

    def insert_after(self, after, child):
        """ Inserts a child into this item after the specified item.

        This removes the child from its current parent (if it has one).
        Returns (index of 'after', child).
        Raises ValueError if 'after' is not a child of this item.
        """

        # Validate the anchor before mutating anything.
        index = self.children.index(after)
        if child is not after:
            # Detach first so the anchor index reflects the list the child
            # is actually inserted into.
            if child.parent is not None:
                child.parent.remove(child)
            index = self.children.index(after)
            self.insert(index + 1, child)

        return (index, child)
class ConnectionMatrixViewer(HasTraits):
    """ Interactive image view of one or more labelled connection matrices.

    The user picks the matrix via the 'data_name' Enum; hovering (through
    CustomTool's xval/yval) shows the From/To labels and the cell value.
    """

    tplot = Instance(Plot)
    plot = Instance(Component)
    custtool = Instance(CustomTool)
    colorbar = Instance(ColorBar)

    fro = Any
    to = Any
    data = None
    val = Float
    nodelabels = Any

    traits_view = View(
        Group(Item('plot',
                   editor=ComponentEditor(size=(800, 600)),
                   show_label=False),
              HGroup(
                  Item('fro', label="From", style='readonly', springy=True),
                  Item('to', label="To", style='readonly', springy=True),
                  Item('val', label="Value", style='readonly', springy=True),
              ),
              orientation="vertical"),
        Item('data_name', label="Edge key"),
        # handler=CustomHandler(),
        resizable=True,
        title="Connection Matrix Viewer")

    def __init__(self, nodelabels, matdict, **traits):
        """ Starts a matrix inspector

        Parameters
        ----------
        nodelabels : list
            List of strings of labels for the rows of the matrix
        matdict : dictionary
            Keys are the edge type and values are NxN Numpy arrays
        """
        # Initialise through our own MRO; super(HasTraits, ...) skipped
        # HasTraits itself.
        super(ConnectionMatrixViewer, self).__init__(**traits)

        # list() keeps this working whether keys() returns a list or a view.
        edge_keys = list(matdict.keys())
        self.add_trait('data_name', Enum(edge_keys))
        self.data_name = edge_keys[0]

        self.data = matdict
        # Store on the declared 'nodelabels' trait (was misspelled before).
        self.nodelabels = nodelabels

        self.plot = self._create_plot_component()

        # set trait notification on customtool
        self.custtool.on_trait_change(self._update_fields, "xval")
        self.custtool.on_trait_change(self._update_fields, "yval")

    @property
    def nodelables(self):
        """ Backward-compatible alias for the misspelled old attribute. """
        return self.nodelabels

    def _data_name_changed(self, old, new):
        """ Swaps the displayed matrix when another edge key is selected. """
        self.pd.set_data("imagedata", self.data[self.data_name])
        self.tplot.title = "Connection Matrix for %s" % self.data_name

    def _update_fields(self):
        """ Updates From/To/Value from the tool's cursor position. """
        # map mouse location to array index; the image is drawn with
        # 1-based bounds (0.5 .. dim + 0.5), hence the -1.
        frotmp = int(round(self.custtool.yval) - 1)
        totmp = int(round(self.custtool.xval) - 1)

        # check if within range
        sh = self.data[self.data_name].shape
        # assume matrix whose shape is (# of rows, # of columns)
        if 0 <= frotmp < sh[0] and 0 <= totmp < sh[1]:
            row = " (index: %i" % (frotmp + 1) + ")"
            col = " (index: %i" % (totmp + 1) + ")"
            self.fro = " " + str(self.nodelabels[frotmp]) + row
            self.to = " " + str(self.nodelabels[totmp]) + col
            self.val = self.data[self.data_name][frotmp, totmp]

    def _create_plot_component(self):
        """ Builds the image plot, colorbar and tools; returns the container. """

        # Create a plot data object and give it this data
        self.pd = ArrayPlotData()
        self.pd.set_data("imagedata", self.data[self.data_name])

        # find dimensions
        xdim = self.data[self.data_name].shape[1]
        ydim = self.data[self.data_name].shape[0]

        # Create the plot
        self.tplot = Plot(self.pd, default_origin="top left")
        self.tplot.x_axis.orientation = "top"
        self.tplot.img_plot("imagedata",
                            name="my_plot",
                            xbounds=(0.5, xdim + 0.5),
                            ybounds=(0.5, ydim + 0.5),
                            colormap=jet)

        # Tweak some of the plot properties
        self.tplot.title = "Connection Matrix for %s" % self.data_name
        self.tplot.padding = 80

        # Right now, some of the tools are a little invasive, and we need the
        # actual CMapImage object to give to them
        self.my_plot = self.tplot.plots["my_plot"][0]

        # Attach some tools to the plot
        self.tplot.tools.append(PanTool(self.tplot))
        zoom = ZoomTool(component=self.tplot, tool_mode="box", always_on=False)
        self.tplot.overlays.append(zoom)

        # my custom tool to get the connection information
        self.custtool = CustomTool(self.tplot)
        self.tplot.tools.append(self.custtool)

        # Create the colorbar, handing in the appropriate range and colormap
        colormap = self.my_plot.color_mapper
        self.colorbar = ColorBar(
            index_mapper=LinearMapper(range=colormap.range),
            color_mapper=colormap,
            plot=self.my_plot,
            orientation='v',
            resizable='v',
            width=30,
            padding=20)

        self.colorbar.padding_top = self.tplot.padding_top
        self.colorbar.padding_bottom = self.tplot.padding_bottom

        # create a range selection for the colorbar
        self.range_selection = RangeSelection(component=self.colorbar)
        self.colorbar.tools.append(self.range_selection)
        self.colorbar.overlays.append(
            RangeSelectionOverlay(component=self.colorbar,
                                  border_color="white",
                                  alpha=0.8,
                                  fill_color="lightgray"))

        # we also want to the range selection to inform the cmap plot of
        # the selection, so set that up as well
        self.range_selection.listeners.append(self.my_plot)

        # Create a container to position the plot and the colorbar side-by-side
        container = HPlotContainer(use_backbuffer=True)
        container.add(self.tplot)
        container.add(self.colorbar)
        container.bgcolor = "white"

        return container
class ChainedWizardController(WizardController):
    """ A wizard controller that can be chained with others.

    This controller's own pages come first, followed by the pages of
    'next_controller' (recursively).
    """

    #### 'ChainedWizardController' interface ##################################

    # The next chained wizard controller.
    next_controller = Instance(IWizardController)

    ###########################################################################
    # 'IWizardController' interface.
    ###########################################################################

    def get_next_page(self, page):
        """ Returns the next page, or None if 'page' is the last one. """

        next_page = None

        if page in self._pages:
            if page is not self._pages[-1]:
                index = self._pages.index(page)
                next_page = self._pages[index + 1]

            elif self.next_controller is not None:
                # Crossing into the chained controller.
                next_page = self.next_controller.get_first_page()

        elif self.next_controller is not None:
            next_page = self.next_controller.get_next_page(page)

        return next_page

    def get_previous_page(self, page):
        """ Returns the previous page, or None if 'page' is the first one. """

        if page in self._pages:
            index = self._pages.index(page)
            # Without this guard index - 1 == -1 wraps around to our last page.
            if index == 0:
                previous_page = None

            else:
                previous_page = self._pages[index - 1]

        elif self.next_controller is not None:
            if self.next_controller.is_first_page(page):
                # Crossing back into this controller (which may have no pages).
                previous_page = self._pages[-1] if self._pages else None

            else:
                previous_page = self.next_controller.get_previous_page(page)

        else:
            previous_page = None

        return previous_page

    def is_first_page(self, page):
        """ Is the page the first page? """

        if self._pages:
            return page is self._pages[0]

        # With no pages of our own, the first page belongs to the next
        # controller (previously this raised IndexError).
        if self.next_controller is not None:
            return self.next_controller.is_first_page(page)

        return False

    def is_last_page(self, page):
        """ Is the page the last page? """

        if page in self._pages:
            # If page is not this controller's last page, then it cannot be
            # *the* last page.
            if page is not self._pages[-1]:
                is_last = False

            # Otherwise, it is *the* last page if this controller has no next
            # controller or the next controller has no pages.
            elif self.next_controller is None:
                is_last = True

            else:
                is_last = self.next_controller.is_last_page(page)

        elif self.next_controller is not None:
            is_last = self.next_controller.is_last_page(page)

        elif len(self._pages) > 0:
            is_last = False

        else:
            is_last = True

        return is_last

    def dispose_pages(self):
        """ Dispose the wizard's pages, including those of chained controllers. """

        for page in self._pages:
            page.dispose_page()

        if self.next_controller is not None:
            self.next_controller.dispose_pages()

    ###########################################################################
    # 'ChainedWizardController' interface.
    ###########################################################################

    def _get_pages(self):
        """ Returns the pages in the wizard (ours, then the chained ones). """

        pages = self._pages[:]

        if self.next_controller is not None:
            pages.extend(self.next_controller.pages)

        return pages

    def _set_pages(self, pages):
        """ Sets this controller's own pages in the wizard. """

        self._pages = pages

    ###########################################################################
    # Private interface.
    ###########################################################################

    def _update(self):
        """ Checks the completion status of the controller. """

        # The entire wizard is complete when ALL pages are complete.
        for page in self._pages:
            if not page.complete:
                self.complete = False
                break

        else:
            if self.next_controller is not None:
                # fixme: This is a abstraction leak point, since _update is not
                # part of the wizard_controller interface!
                self.next_controller._update()
                self.complete = self.next_controller.complete

            else:
                self.complete = True

    #### Trait event handlers #################################################

    #### Static ####

    def _current_page_changed(self, old, new):
        """ Called when the current page is changed. """

        # Move the 'complete' listener from the old page to the new one.
        if old is not None:
            old.on_trait_change(
                self._on_page_complete, 'complete', remove=True
            )

        if new is not None:
            new.on_trait_change(self._on_page_complete, 'complete')

        if self.next_controller is not None:
            self.next_controller.current_page = new

        self._update()

    def _next_controller_changed(self, old, new):
        """ Called when the next controller is changed. """

        if old is not None:
            old.on_trait_change(
                self._on_controller_complete, 'complete', remove=True
            )

        if new is not None:
            new.on_trait_change(
                self._on_controller_complete, 'complete'
            )

        self._update()

    #### Dynamic ####

    def _on_controller_complete(self, obj, trait_name, old, new):
        """ Called when the next controller's complete state changes. """

        self._update()

    def _on_page_complete(self, obj, trait_name, old, new):
        """ Called when the current page is complete. """

        self._update()
class ExampleScriptWindow(WorkbenchWindow): """ The ExampleScriptWindow class is a workbench window that contains example editors that demonstrate the use of the application scripting framework. """ #### Private interface #################################################### # The action that exits the application. _exit_action = Instance(Action) # The File menu. _file_menu = Instance(MenuManager) # The Label menu. _label_menu = Instance(MenuManager) # The Scripts menu. _scripts_menu = Instance(MenuManager) ########################################################################### # Private interface. ########################################################################### #### Trait initialisers ################################################### def __file_menu_default(self): """ Trait initialiser. """ return MenuManager(self._exit_action, name="&File") def __label_menu_default(self): """ Trait initialiser. """ size_group = Group(LabelIncrementSizeAction(window=self), LabelDecrementSizeAction(window=self)) normal = LabelNormalFontAction(window=self, id='normal', style='radio', checked=True) bold = LabelBoldFontAction(window=self, id='bold', style='radio') italic = LabelItalicFontAction(window=self, id='italic', style='radio') style_group = Group(normal, bold, italic, id='style') return MenuManager(size_group, style_group, name="&Label") def __scripts_menu_default(self): """ Trait initialiser. """ # ZZZ: This is temporary until we put the script into a view. get_script_manager().on_trait_event(self._on_script_updated, 'script_updated') return MenuManager(StartRecordingAction(), StopRecordingAction(), name="&Scripts") def __exit_action_default(self): """ Trait initialiser. """ return Action(name="E&xit", on_perform=self.workbench.exit) def _editor_manager_default(self): """ Trait initialiser. """ return ExampleEditorManager() def _menu_bar_manager_default(self): """ Trait initialiser. 
""" return MenuBarManager(self._file_menu, self._label_menu, self._scripts_menu, window=self) def _tool_bar_manager_default(self): """ Trait initialiser. """ return ToolBarManager(self._exit_action, show_tool_names=False) # ZZZ: This is temporary until we put the script into a view. def _on_script_updated(self, script_manager): script = script_manager.script if script: print script, else: print "Script empty"
class PlotOMatic(HasTraits):
    """ Top-level Plot-o-matic model: I/O drivers, variables and viewers.

    The tree on the left shows the input drivers (with their decoders) and
    the viewers; the selected viewer is shown on the right above the
    variables pane.
    """

    io_driver_list = Instance(IODriverList)
    variables = Instance(Variables)
    viewers = Instance(Viewers)
    selected_viewer = Instance(Viewer)

    handler = PlotOMaticHandler()

    viewer_node = TreeNode(node_for=[Viewer],
                           auto_open=True,
                           label='name',
                           menu=Menu(handler.remove_viewer_action),
                           icon_path='icons/',
                           icon_item='plot.png')

    tree_editor = TreeEditor(
        nodes=[
            TreeNode(
                node_for=[IODriverList],
                auto_open=True,
                children='io_drivers',
                label='=Input Drivers',
                menu=Menu(handler.refresh_tree_action,
                          handler.add_io_driver_actions_menu),
                view=View(),
            ),
            TreeNode(node_for=[IODriver],
                     auto_open=True,
                     children='_decoders',
                     label='name',
                     add=[DataDecoder],
                     menu=Menu(handler.remove_io_driver_action,
                               handler.refresh_tree_action,
                               handler.add_decoder_actions_menu),
                     icon_path='icons/',
                     icon_open='input.png',
                     icon_group='input.png'),
            TreeNode(node_for=[DataDecoder],
                     auto_open=True,
                     children='',
                     label='name',
                     menu=Menu(handler.refresh_tree_action,
                               handler.remove_decoder_action),
                     icon_path='icons/',
                     icon_item='decoder.png'),
            TreeNode(node_for=[IODriverList],
                     auto_open=True,
                     children='viewers',
                     label='=Viewers',
                     menu=Menu(handler.refresh_tree_action,
                               handler.add_viewer_actions_menu),
                     view=View()),
            viewer_node
        ],
        hide_root=True,
        orientation='vertical')

    view = View(HSplit(
        Item(name='io_driver_list',
             editor=tree_editor,
             resizable=True,
             show_label=False,
             width=.32),
        VSplit(
            Item(name='selected_viewer',
                 style='custom',
                 resizable=True,
                 show_label=False,
                 editor=InstanceEditor(view='view')),
            Item(name='variables',
                 show_label=False,
                 style='custom',
                 height=.3))),
                menubar=MenuBar(handler.file_menu, handler.data_menu),
                title='Plot-o-matic',
                resizable=True,
                width=1000,
                height=600,
                handler=PlotOMaticHandler())

    def __init__(self, **kwargs):
        HasTraits.__init__(self, **kwargs)
        # Selecting a viewer node in the tree routes to click_viewer.
        self.viewer_node.on_select = self.click_viewer

    def click_viewer(self, viewer):
        """ Makes 'viewer' the selected one, locally and in 'viewers'. """
        self.selected_viewer = viewer
        self.viewers.select_viewer(viewer)

    def start(self):
        """ Starts all I/O drivers, then the viewers. """
        self.io_driver_list.start_all()
        self.viewers.start()

    def stop(self):
        """ Stops the viewers, then all I/O drivers (reverse of start). """
        self.viewers.stop()
        self.io_driver_list.stop_all()

    def get_config(self):
        """ Returns {'io_drivers': ..., 'viewers': ...} for saving. """
        return {
            'io_drivers': self.io_driver_list.get_config(),
            'viewers': self.viewers.get_config(),
        }

    def set_config(self, config):
        """ Applies whichever sections are present, then clears variables. """
        sections = (('io_drivers', 'io_driver_list'), ('viewers', 'viewers'))
        for key, attr in sections:
            if key in config:
                getattr(self, attr).set_config(config[key])
        self.variables.clear()
class Slot(Variable):
    """A trait for an object of a particular type or implementing a particular
    interface. Both Traits Interfaces and zope.interface.Interfaces are
    supported.

    'klass' may also be an instance; its class is then used as the slot type
    and the instance becomes the default value.
    """

    def __init__(self, klass=object, allow_none=True, factory=None,
                 args=None, kw=None, **metadata):
        default_value = None
        try:
            iszopeiface = issubclass(klass, zope.interface.Interface)
        except TypeError:
            iszopeiface = False
            # issubclass() rejected klass, so it may be an instance rather
            # than a class: use it as the default and slot its class.
            if not isclass(klass):
                default_value = klass
                klass = klass.__class__

        metadata.setdefault('copy', 'deep')

        self._allow_none = allow_none
        self.klass = klass

        # Containers need their parent set on assignment (see post_setattr).
        if has_interface(klass, IContainer) or \
           (isclass(klass) and IContainer.implementedBy(klass)):
            self._is_container = True
        else:
            self._is_container = False

        if iszopeiface:
            # zope interfaces are checked directly in validate().
            self._instance = None
            self.factory = factory
            self.args = args
            self.kw = kw
        else:
            self._instance = Instance(klass=klass, allow_none=allow_none,
                                      factory=factory, args=args, kw=kw,
                                      **metadata)
            # Compare against None: a supplied instance that happens to be
            # falsy (e.g. an empty container) is still a valid default.
            if default_value is not None:
                self._instance.default_value = default_value
            else:
                default_value = self._instance.default_value
        super(Slot, self).__init__(default_value, **metadata)

    def validate(self, obj, name, value):
        ''' wrapper around Enthought validate method'''

        if value is None:
            if self._allow_none:
                return value
            self.validate_failed(obj, name, value)

        if self._instance is None:  # our iface is a zope.interface
            if not self.klass.providedBy(value):
                self._iface_error(obj, name, self.klass.__name__)
        else:
            try:
                value = self._instance.validate(obj, name, value)
            except Exception:
                if issubclass(self._instance.klass, Interface):
                    self._iface_error(obj, name,
                                      self._instance.klass.__name__)
                else:
                    obj.raise_exception("%s must be an instance of class '%s'"
                                        % (name, self._instance.klass.__name__),
                                        TypeError)

        return value

    def post_setattr(self, obj, name, value):
        '''Containers must know their place within the hierarchy, so set their
        parent here.  This keeps side effects out of validate()'''

        if self._is_container and value is not None:
            if value.parent is not obj:
                value.parent = obj
            # VariableTrees also need to know their iotype
            if hasattr(value, '_iotype'):
                value._iotype = self.iotype

    def _iface_error(self, obj, name, iface_name):
        """ Raises TypeError via obj saying 'name' must provide the interface. """
        obj.raise_exception("%s must provide interface '%s'" %
                            (name, iface_name), TypeError)

    def get_attribute(self, name, value, trait, meta):
        """Return the attribute dictionary for this variable. This dict is
        used by the GUI to populate the edit UI. Slots also return an
        attribute dictionary for the slot pane.

        name: str
          Name of variable

        value: object
          The value of the variable

        trait: CTrait
          The variable's trait

        meta: dict
          Dictionary of metadata for this variable

        Returns (io_attr, slot_attr).
        """

        io_attr = {}
        io_attr['name'] = name
        io_attr['type'] = trait.trait_type.klass.__name__
        io_attr['ttype'] = 'slot'

        slot_attr = {}
        slot_attr['name'] = name

        if value is None:
            slot_attr['filled'] = None
        elif isinstance(value, list) and not value:
            # 'value is []' was always False (identity with a fresh list).
            slot_attr['filled'] = []
        else:
            slot_attr['filled'] = type(value).__name__

        slot_attr['klass'] = io_attr['type']
        slot_attr['containertype'] = 'singleton'

        for field in meta:
            if field not in gui_excludes:
                slot_attr[field] = meta[field]

        return io_attr, slot_attr
class TemplateDataNames(HasPrivateTraits):
    """ Holds a list of TemplateDataName objects, keeps each one's 'context'
    in sync with this object's 'context', and mirrors them as
    VirtualDataName rows for a table editor.
    """

    #-- Public Traits ---------------------------------------------------------

    # The data context to which bindings are made:
    context = Instance(ITemplateDataContext)

    # The current set of data names to be bound to the context:
    data_names = List(TemplateDataName)

    # The list of unresolved, required bindings:
    unresolved_data_names = Property(depends_on='data_names.resolved')

    # The list of optional bindings:
    optional_data_names = Property(depends_on='data_names.optional')

    # The list of unresolved optional bindings:
    unresolved_optional_data_names = Property(
        depends_on='data_names.[resolved,optional]')

    #-- Private Traits --------------------------------------------------------

    # List of 'virtual' data names for use by table editor:
    virtual_data_names = List

    # The list of table editor columns:
    table_columns = Property(depends_on='data_names')  # List( ObjectColumn )

    #-- Traits View Definitions -----------------------------------------------

    view = View(
        Item('virtual_data_names',
             show_label=False,
             style='custom',
             editor=table_editor))

    #-- Property Implementations ----------------------------------------------

    @cached_property
    def _get_unresolved_data_names(self):
        return [
            dn for dn in self.data_names
            if (not dn.resolved) and (not dn.optional)
        ]

    @cached_property
    def _get_optional_data_names(self):
        return [dn for dn in self.data_names if dn.optional]

    @cached_property
    def _get_unresolved_optional_data_names(self):
        return [
            dn for dn in self.data_names if (not dn.resolved) and dn.optional
        ]

    @cached_property
    def _get_table_columns(self):
        # One 'Name' column per item of the widest data name.  An empty
        # data_names list used to raise from max([]), and n == 0 divided by
        # zero; both now fall back to the single-column layout.
        n = max([len(dn.items) for dn in self.data_names] or [1])
        if n <= 1:
            return std_columns + [
                BindingsColumn(name='value0', label='Name', width=0.43)
            ]

        width = 0.43 / n

        return (std_columns + [
            BindingsColumn(name='value%d' % i,
                           index=i,
                           label='Name %d' % (i + 1),
                           width=width) for i in range(n)
        ])

    #-- Trait Event Handlers --------------------------------------------------

    def _context_changed(self, context):
        for data_name in self.data_names:
            data_name.context = context

    def _data_names_changed(self, old, new):
        """ Handles the list of 'data_names' being changed.
        """
        # Make sure that all of the names are unique:
        new = set(new)

        # Update the old and new context links:
        self._update_contexts(old, new)

        # Update the list of virtual names based on the new set, ordered by
        # description (key= replaces the Python-2-only cmp comparator).
        dns = [VirtualDataName(data_name=dn) for dn in new]
        dns.sort(key=lambda vdn: vdn.description)
        self.virtual_data_names = dns

    def _data_names_items_changed(self, event):
        # Update the old and new context links:
        old, new = event.old, event.new
        self._update_contexts(old, new)

        # Update the list of virtual names based on the old and new sets:
        i = event.index
        self.virtual_data_names[i:i + len(old)] = [
            VirtualDataName(data_name=dn) for dn in new
        ]

    #-- Private Methods -------------------------------------------------------

    def _update_contexts(self, old, new):
        """ Updates the data context for an old and new set of data names.
        """
        for data_name in old:
            data_name.context = None

        context = self.context
        for data_name in new:
            data_name.context = context
class PlotOMaticHandler(Controller):
    """ Controller providing Plot-o-matic's File and Data menus, and the
        add/remove/refresh actions used by the IO driver / decoder / viewer
        tree.
    """

    # ------------ Menu related --------------------

    exit_action = Action(name='&Exit', action='exit')
    # Each attribute name matches its label and handler method. The File
    # menu shows: Exit, separator, Open Session, Save Session.
    open_session_action = Action(name='&Open Session', action='open_session')
    save_session_action = Action(name='&Save Session', action='save_session')
    file_menu = Menu(exit_action, Separator(),
                     open_session_action, save_session_action,
                     name='&File')

    def exit(self, uii):
        # TODO: not implemented; the action currently only reports itself.
        print('Exit called, really should implement this')

    def save_session(self, uii):
        """ Prompts for a file name and writes the object's config to it as
            YAML.
        """
        filename = save_file(
            filter='Plot-o-matic session (*.plot_session)|*.plot_session|All files (*)|*',
            file_name='my_session.plot_session')
        # An empty/None name presumably means the dialog was cancelled.
        if filename:
            print("Saving session as '%s'" % filename)
            session = uii.object.get_config()
            # 'with' closes the file even if yaml.dump raises.
            with open(filename, 'w') as fp:
                yaml.dump(session, fp, default_flow_style=False)

    def open_session(self, uii):
        """ Prompts for a session file, loads it as YAML and applies it with
            set_config().
        """
        filename = open_file(
            filter='Plot-o-matic session (*.plot_session)|*.plot_session|All files (*)|*',
            file_name='my_session.plot_session')
        if filename:
            print("Opening session '%s'" % filename)
            # NOTE(review): yaml.load can construct arbitrary Python objects,
            # so only open session files from trusted sources. yaml.safe_load
            # would be safer but may reject configs that yaml.dump wrote with
            # Python-specific tags -- confirm before switching.
            with open(filename, 'r') as fp:
                session = yaml.load(fp)
            uii.object.set_config(session)

    clear_data_action = Action(name='&Clear Data', action='clear_data')
    save_data_action = Action(name='&Save Data Set', action='save_data')
    open_data_action = Action(name='&Open Data Set', action='open_data')
    data_menu = Menu(clear_data_action, Separator(),
                     save_data_action, open_data_action,
                     name='&Data')

    def clear_data(self, uii):
        uii.object.variables.clear()

    def save_data(self, uii):
        """ Prompts for a file name and saves the current data set to it. """
        filename = save_file(
            filter='Plot-o-matic data set (*.plot_data)|*.plot_data|All files (*)|*',
            file_name='my_data.plot_data')
        if filename:
            uii.object.variables.save_data_set(filename)
            print("Saved data set '%s'" % filename)

    def open_data(self, uii):
        """ Prompts for a data set file and opens it. """
        filename = open_file(
            filter='Plot-o-matic data set (*.plot_data)|*.plot_data|All files (*)|*',
            file_name='my_data.plot_data')
        if filename:
            uii.object.variables.open_data_set(filename)
            print("Opened data set '%s'" % filename)

    # ------------ Tree related --------------------

    remove_io_driver_action = Action(
        name='Remove', action='handler.remove_io_driver(editor,object)')
    add_io_driver_actions_menu = Instance(Menu)

    remove_decoder_action = Action(
        name='Remove', action='handler.remove_decoder(editor,object)')
    add_decoder_actions_menu = Instance(Menu)

    remove_viewer_action = Action(
        name='Remove', action='handler.remove_viewer(editor,object)')
    add_viewer_actions_menu = Instance(Menu)

    refresh_tree_action = Action(
        name='Refresh', action='handler.refresh_tree(editor)')

    def refresh_tree(self, editor):
        editor.update_editor()

    def _make_add_menu(self, plugins, handler_method):
        """ Builds an 'Add' menu with one action per plugin class. Each action
            calls ``handler.<handler_method>(editor,object,"<plugin name>")``.
        """
        actions = [
            Action(name=plugin.__name__,
                   action='handler.%s(editor,object,"%s")'
                          % (handler_method, plugin.__name__))
            for plugin in plugins]
        return Menu(name='Add', *actions)

    def _add_io_driver_actions_menu_default(self):
        return self._make_add_menu(find_io_driver_plugins(), 'add_io_driver')

    def remove_io_driver(self, editor, io_driver_object):
        io_driver_list = editor._menu_parent_object
        io_driver_list._remove_io_driver(io_driver_object)
        editor.update_editor()

    def add_io_driver(self, editor, io_driver_list, new_io_driver_name):
        new_io_driver = get_io_driver_plugin_by_name(new_io_driver_name)()
        io_driver_list._add_io_driver(new_io_driver)
        editor.update_editor()

    def _add_decoder_actions_menu_default(self):
        return self._make_add_menu(find_decoder_plugins(), 'add_decoder')

    def remove_decoder(self, editor, decoder_object):
        parent_io_driver = editor._menu_parent_object
        parent_io_driver._remove_decoder(decoder_object)
        editor.update_editor()

    def add_decoder(self, editor, io_driver, decoder_name):
        new_decoder = get_decoder_plugin_by_name(decoder_name)()
        io_driver._add_decoder(new_decoder)
        editor.update_editor()

    def _add_viewer_actions_menu_default(self):
        return self._make_add_menu(find_viewer_plugins(), 'add_viewer')

    def remove_viewer(self, editor, viewer_object):
        viewers = editor._menu_parent_object.viewers_instance
        viewers._remove_viewer(viewer_object)
        editor.update_editor()

    def add_viewer(self, editor, object, viewer_name):
        new_viewer = get_viewer_plugin_by_name(viewer_name)()
        object.viewers_instance._add_viewer(new_viewer)
        editor.update_editor()
class GenericModule(Module):
    """
    Defines a GenericModule which is a collection of mayavi
    filters/components put together.  This is very convenient and
    useful to create new modules.

    Note that all components including the actor must be passed as a
    list to set the components trait.
    """

    # The *optional* Contour component to which we must listen to if
    # any.  This is needed for modules that use a contour component
    # because when we turn on filled contours the mapper must switch to
    # use cell data.
    contour = Instance('enthought.mayavi.components.contour.Contour',
                       allow_none=True)

    # The *optional* Actor component for which the LUT must be set.  If
    # None is specified here, we will attempt to automatically determine
    # it.
    actor = Instance(Actor, allow_none=True)

    # Should we use the scalar LUT or the vector LUT?
    lut_mode = Enum('scalar', 'vector')

    ########################################
    # Private traits.

    # Is the pipeline ready?  Used internally.
    _pipeline_ready = Bool(False)

    ######################################################################
    # `object` interface.
    ######################################################################
    def __get_pure_state__(self):
        # Need to pickle the components.
        d = super(GenericModule, self).__get_pure_state__()
        d['components'] = self.components
        # Internal flag; not part of the persisted state.
        d.pop('_pipeline_ready', None)
        return d

    def __set_pure_state__(self, state):
        # If we are already running, the components would be started
        # automatically when the module handles them, even though their
        # state has not been restored yet.  So we disable running here and
        # restore it at the end.
        running = self.running
        self.running = False
        # Remove the actor/contour states since we don't want these
        # unpickled; the matching objects are looked up among the restored
        # components below.
        actor_st = state.pop('actor', None)
        contour_st = state.pop('contour', None)
        # Create and set the components.
        handle_children_state(self.components, state.components)
        components = self.components

        # Restore our state using set_state.
        state_pickler.set_state(self, state)

        # Now set our actor and contour by finding the right one among the
        # restored components.
        if actor_st is not None:
            actor = self._find_object_for_state(actor_st, state.components,
                                                components)
            if actor is not None:
                self.actor = actor
        if contour_st is not None:
            contour = self._find_object_for_state(contour_st,
                                                  state.components,
                                                  components)
            if contour is not None:
                self.contour = contour

        # Now start all components if needed.
        self._start_components()
        self.running = running

    ######################################################################
    # `HasTraits` interface.
    ######################################################################
    def default_traits_view(self):
        """Returns the default traits view for this object."""
        le = ListEditor(use_notebook=True,
                        deletable=False,
                        export='DockWindowShell',
                        page_name='.name')
        view = View(Group(Item(name='components',
                               style='custom',
                               show_label=False,
                               editor=le,
                               resizable=True),
                          show_labels=False),
                    resizable=True)
        return view

    ######################################################################
    # `Module` interface.
    ######################################################################
    def setup_pipeline(self):
        """Setup the pipeline."""
        # Needed because a user may have setup the components by setting
        # the default value of the trait in the subclass in which case
        # the components_changed handler will never be called leading to
        # problems.
        if self.components and not self._pipeline_ready:
            self._components_changed([], self.components)

    def update_pipeline(self):
        """This method *updates* the tvtk pipeline when data upstream is
        known to have changed.

        This method is invoked (automatically) when the input fires a
        `pipeline_changed` event.
        """
        mm = self.module_manager
        if mm is None:
            return

        self._setup_pipeline()
        # Propagate the event.
        self.pipeline_changed = True

    def update_data(self):
        """This method does what is necessary when upstream data
        changes.

        This method is invoked (automatically) when any of the inputs
        sends a `data_changed` event.
        """
        # Propagate the data_changed event.
        self.data_changed = True

    ######################################################################
    # Private interface.
    ######################################################################
    def _find_object_for_state(self, obj_state, component_states,
                               components):
        """Returns the first non-None result of `find_object_given_state`
        over the paired (component state, component) sequences, or None.
        """
        for cst, c in zip(component_states, components):
            obj = find_object_given_state(obj_state, cst, c)
            if obj is not None:
                return obj
        return None

    def _setup_pipeline(self):
        """Sets up the objects in the pipeline."""
        mm = self.module_manager
        if mm is None or not self.components:
            return
        # Our input.
        my_input = mm.source
        components = self.components
        if not self._pipeline_ready:
            # Hook up our first component.
            components[0].inputs = [my_input]
            # Hook up the others to each other: each component takes the
            # previous one as its input.
            for previous, component in zip(components, components[1:]):
                component.inputs = [previous]
            self._pipeline_ready = True
        # Start components.
        self._start_components()
        # Setup the LUT of any actors.
        self._lut_mode_changed(self.lut_mode)

    def _handle_components(self, removed, added):
        super(GenericModule, self)._handle_components(removed, added)
        for component in added:
            # Give unnamed components their class name as a display name.
            if not component.name:
                component.name = component.__class__.__name__
            # Use the first Actor added as our actor if none is set yet.
            if self.actor is None:
                if isinstance(component, Actor):
                    self.actor = component

        if not self.components:
            self.input_info.datasets = ['none']
        else:
            self.input_info.copy_traits(self.components[0].input_info)

        self._pipeline_ready = False
        self._setup_pipeline()

    def _lut_mode_changed(self, value):
        """Static traits listener."""
        mm = self.module_manager
        if mm is None:
            return
        lm = mm.scalar_lut_manager
        if value == 'vector':
            lm = mm.vector_lut_manager

        if self.actor is not None:
            self.actor.set_lut(lm.lut)

    def _actor_changed(self, old, new):
        self._lut_mode_changed(self.lut_mode)

    def _filled_contours_changed_for_contour(self, value):
        """When filled contours are enabled, the mapper should use the
        cell data, otherwise it should use the default scalar mode.
        """
        if self.actor is None:
            return
        if value:
            self.actor.mapper.scalar_mode = 'use_cell_data'
        else:
            self.actor.mapper.scalar_mode = 'default'
        self.render()

    def _start_components(self):
        """Starts every component whose first input already has outputs."""
        for component in self.components:
            if len(component.inputs) > 0 and \
               len(component.inputs[0].outputs) > 0:
                component.start()