class BaseFilter(HasTraits):
    """Base image filter whose ``process`` is the identity.

    Presumably meant to be subclassed with a real ``process``
    implementation; this base simply hands the image back.
    """

    # Extra keyword arguments meant to be forwarded to an nd filter.
    kw = Dict({}, desc='additional keyword arguments to nd filter')

    # Read-only display name (only a getter is defined below).
    name_ = Property(Str)

    def process(self, image):
        """Return *image* unchanged."""
        result = image
        return result

    def _get_name_(self):
        """Getter for ``name_``: the concrete class's name."""
        klass = self.__class__
        return klass.__name__
class OptProblem(ArchitectureAssembly):
    """Class for specifying test problems for optimization algorithms and
    architectures"""

    solution = Dict({}, iotype="in", desc="dictionary of expected values for "
                    "all des_vars and coupling_vars")

    def check_solution(self, strict=False):
        """Return a dictionary of errors (actual - expected) for all
        des_vars, coupling_vars, and objectives.

        strict: Boolean (optional)
            If True, then an error will be raised (via raise_exception,
            as ValueError) for the first des_var, coupling_var, or objective
            where no solution is provided. If False, each item without a
            solution is skipped individually and the remaining items are
            still checked. Defaults to False.

        Returns a dict keyed like the parameters / coupling vars /
        objectives. Coupling-var entries are ``(indep_error, dep_error)``
        tuples; all other entries are single differences.
        """
        error = {}
        missing = object()

        def expected(key, kind):
            # Look up one expected value. Checking per item (instead of
            # wrapping a whole loop in try/except KeyError) means one missing
            # entry no longer aborts the rest of the loop, and KeyErrors
            # raised by evaluate() itself are no longer swallowed.
            sol = self.solution.get(key, missing)
            if sol is missing and strict:
                self.raise_exception(
                    "No solution was given for the %s %s" % (kind, str(key)),
                    ValueError)
            return sol

        for k, v in self.get_parameters().iteritems():
            sol = expected(k, 'des_var')
            if sol is not missing:
                error[k] = v.evaluate() - sol

        for k, v in self.get_coupling_vars().iteritems():
            sol = expected(k, 'coupling_var')
            if sol is not missing:
                error[k] = (v.indep.evaluate() - sol, v.dep.evaluate() - sol)

        for k, v in self.get_objectives().iteritems():
            sol = expected(k, 'objective')
            if sol is not missing:
                error[k] = v.evaluate() - sol

        return error
class C(HasTraits):
    """Holder for a Dict trait whose values are themselves List traits."""

    # Int keys mapping to lists of ints.
    a = Dict(Int, List(Int))

    def __init__(self):
        # Let HasTraits initialise first, then give the dict a non-trivial
        # starting value (one key holding a two-element list).
        super(C, self).__init__()
        initial = {1: [2, 3]}
        self.a = initial
class MyClass(HasTraits):
    """ Dummy HasTraits class exposing a Dict and reporting item changes """

    # Dict trait pre-populated with four fruit entries.
    d = Dict({"a": "apple", "b": "banana", "c": "cherry", "d": "durian"})

    def __init__(self, callback):
        """Store *callback*; it receives the TraitDictEvent for changes to d.

        Note: HasTraits.__init__ is not called here.
        """
        self.callback = callback

    def _d_items_changed(self, event):
        # Traits static handler (by naming convention) for item-level
        # changes to `d`; forward the event only if a callback is set.
        handler = self.callback
        if not handler:
            return
        handler(event)
class C(HasTraits):
    """Keeps a list view of a dict of Parameters (dict -> list only)."""

    # Parameters keyed by name; the default holds 'p1' and 'p2'.
    d = Dict(String, Instance(Parameter),
             {'p1': Parameter(name='p1', value=0.1),
              'p2': Parameter(name='p2', value=0.2)})

    # List mirror of d's values; this is what the view displays.
    d_view = List(Instance(Parameter))

    def _d_view_default(self):
        # Initial mirror: the current values of d.
        return list(self.d.values())

    @on_trait_change('d.value')
    def update_view(self):
        # Rebuild the mirror when a 'value' inside d changes (per the
        # 'd.value' extended trait name).
        print('updating view')
        self.d_view = list(self.d.values())

    @on_trait_change('d_view.value')
    def update_d(self):
        # Placeholder: only logs; d is not modified here.
        print('updating d')

    view = View(Item('d_view', editor=Parameter.editor))
class GraphEdge(HasPrivateTraits): """ Defines a representation of a graph edge for use by the graph editor and the graph editor factory classes. """ # head_nodes = List( Instance(HasTraits) ) head_name = Str # tail_nodes = List( Instance(HasTraits) ) tail_name = Str # List of object classes and/or interfaces that the edge applies to. edge_for = List(Any) # Dot attributes to be applied to the edge. dot_attr = Dict(Str, Any)
class ITVTKActorModel(HasTraits): """ An interface for view models that can control a TVTK scene's contents. """ # This maintains a dictionary mapping objects (by identity) to lists (or # single items) of TVTK Actors or 3D Widgets that represent them in the # scene. Adding and removing objects from this dictionary adds and removes # them from the scene. This is the trait that will be edited by a # ActorEditor. actor_map = Dict() # Turn off rendering such that multiple adds/removes can be refreshed at # once. disable_render = Bool(False) # Send this event in order to force a rendering of the scene. do_render = Event()
class ActorEditor(BasicEditorFactory): """ An editor factory for TVTK scenes. """ # The class of the editor object to be constructed. klass = _ActorEditor # The class or factory function for creating the actual scene object. scene_class = Callable(DecoratedScene) # Keyword arguments to pass to the scene factory. scene_kwds = Dict() # The name of the trait used for ITVTKActorModel.disable_render. disable_render_name = Str('disable_render') # The name of the trait used for ITVTKActorModel.do_render. do_render_name = Str('do_render')
class ActorModel(ITVTKActorModel):
    """Demo model that shows exactly one actor or widget, picked by
    'actor_type'."""

    # Which single actor/widget is currently shown.
    actor_type = Enum('cone', 'sphere', 'plane_widget', 'box_widget')

    #########################
    # ITVTKView Model traits.

    # Dictionary mapping objects (by identity) to lists (or single items) of
    # TVTK Actors or 3D Widgets that represent them in the scene. Adding and
    # removing entries adds and removes them from the scene. This is the
    # trait edited by an ActorEditor.
    actor_map = Dict()

    ######################
    view = View(
        Item(name='actor_type'),
        Item(name='actor_map',
             editor=ActorEditor(scene_kwds={'background': (0.2, 0.2, 0.2)}),
             show_label=False,
             resizable=True,
             height=500,
             width=500))

    def __init__(self, **traits):
        super(ActorModel, self).__init__(**traits)
        # Build the scene contents for the initial actor_type explicitly;
        # presumably the change handler does not fire for the default value.
        self._actor_type_changed(self.actor_type)

    ####################################
    # Static trait change handlers.

    def _actor_type_changed(self, value):
        """Replace actor_map with a single fresh actor/widget for *value*."""
        # Builders are lambdas so that only the selected factory is looked up
        # and called.
        builders = {
            'cone': lambda: {'cone': actors.cone_actor()},
            'sphere': lambda: {'sphere': actors.sphere_actor()},
            'plane_widget': lambda: {'plane_widget': tvtk.PlaneWidget()},
            'box_widget': lambda: {'box_widget': tvtk.BoxWidget()},
        }
        build = builders.get(value)
        if build is not None:
            self.actor_map = build()
class WorkbenchWindowMemento(HasTraits): """ A memento for a workbench window. """ # The Id of the active perspective. active_perspective_id = Str # The memento for the editor area. editor_area_memento = Any # Mementos for each perspective that has been seen. # # The keys are the perspective Ids, the values are the toolkit-specific # mementos. perspective_mementos = Dict(Str, Any) # The position of the window. position = Tuple # The size of the window. size = Tuple
class TrimCase(HasTraits):
    """Trim settings for an AVL run case, exposed as editable Parameters."""

    # The owning run case (intended type: Instance(RunCase)).
    runcase = Any()

    # Trim type; its mapped value (type_) is the command sent to AVL.
    type = Trait('horizontal flight', {'horizontal flight': 'c1',
                                       'looping flight': 'c2'})

    # Parameters parsed from AVL output, keyed by their pattern string.
    parameters = Dict(String, Instance(Parameter))

    # The same Parameters sorted case-insensitively by name, for display.
    parameter_view = List(Parameter, [])

    traits_view = View(Group(Item('type')),
                       Group(Item('parameter_view', editor=Parameter.editor,
                                  show_label=False),
                             label='Parameters'),
                       Item(),  # so that groups are not tabbed
                       kind='livemodal',
                       buttons=['OK'])

    # NOTE: not hooked to trait changes; a disabled decorator was
    # @on_trait_change('type,parameters.value').
    def update_parameters_from_avl(self):
        """Query AVL's trim menu and rebuild the parameter traits.

        Sends 'oper' and then the trim command (type_) to the AVL session,
        parses each line from the '=====' separator onwards with
        RunCase.patterns['parameter'], calls AVL.goto_state on the session,
        then stores the results in 'parameters' and 'parameter_view'.

        Returns self. Raises ValueError if the separator line is missing.
        """
        session = self.runcase.avl
        session.sendline('oper')
        session.expect(AVL.patterns['/oper'])
        session.sendline(self.type_)
        session.expect(AVL.patterns['/oper/m'])

        lines = [raw.strip() for raw in session.before.splitlines()]
        start = lines.index('=================================================')
        regex = RunCase.patterns['parameter']

        params = {}
        for line in lines[start:]:
            match = re.search(regex, line)
            if match is None:
                continue
            fields = match.groupdict()
            key = fields['pattern']
            unit = fields.get('unit', '')
            if unit is None:
                unit = ''
            params[key] = Parameter(name=key, pattern=key, cmd=fields['cmd'],
                                    unit=unit, value=float(fields['val']))

        AVL.goto_state(session)
        self.parameters = params
        self.parameter_view = sorted(params.values(),
                                     key=lambda p: p.name.upper())
        return self
class PyContext(Context, Referenceable):
    """ A naming context for a Python namespace. """

    #### 'Context' interface ##################################################

    # The naming environment in effect for this context.
    environment = Dict(ENVIRONMENT)

    #### 'PyContext' interface ################################################

    # The Python namespace that we represent.
    namespace = Any

    # If the namespace is actually a Python object that has a '__dict__'
    # attribute, then this will be that object (the namespace will be the
    # object's '__dict__').
    obj = Any

    #### 'Referenceable' interface ############################################

    # The object's reference suitable for binding in a naming context.
    reference = Property(Instance(Reference))

    ###########################################################################
    # 'object' interface.
    ###########################################################################

    def __init__(self, **traits):
        """ Creates a new context.

        Raises ValueError if 'namespace' is neither a dictionary nor an
        object with a '__dict__' attribute.

        """

        # Base class constructor.
        super(PyContext, self).__init__(**traits)

        # Any dict (including subclasses such as OrderedDict) is used as the
        # namespace directly; an exact type check would send dict subclasses
        # down the '__dict__' branch and expose their internal attributes
        # instead of their items.
        if not isinstance(self.namespace, dict):
            if hasattr(self.namespace, '__dict__'):
                self.obj = self.namespace
                self.namespace = self.namespace.__dict__

            else:
                raise ValueError('Need a dictionary or a __dict__ attribute')

        return

    ###########################################################################
    # 'Referenceable' interface.
    ###########################################################################

    #### Properties ###########################################################

    def _get_reference(self):
        """ Returns a reference to this object suitable for binding. """

        reference = Reference(
            class_name=self.__class__.__name__,
            addresses=[Address(type='py_context', content=self.namespace)]
        )

        return reference

    ###########################################################################
    # Protected 'Context' interface.
    ###########################################################################

    def _is_bound(self, name):
        """ Is a name bound in this context? """

        return name in self.namespace

    def _lookup(self, name):
        """ Looks up a name in this context. """

        obj = self.namespace[name]

        return naming_manager.get_object_instance(obj, name, self)

    def _bind(self, name, obj):
        """ Binds a name to an object in this context. """

        state = naming_manager.get_state_to_bind(obj, name, self)
        self.namespace[name] = state

        # Trait event notification.
        # An "added" event is fired by the bind method of the base class
        # (which calls this one), so we don't need to do the changed here
        # (which would be the wrong thing anyway) -- LGV
        #
        # self.trait_property_changed('context_changed', None, None)

        return

    def _rebind(self, name, obj):
        """ Rebinds a name to a object in this context. """

        self._bind(name, obj)

        return

    def _unbind(self, name):
        """ Unbinds a name from this context. """

        del self.namespace[name]

        # Trait event notification.
        self.trait_property_changed('context_changed', None, None)

        return

    def _rename(self, old_name, new_name):
        """ Renames an object in this context. """

        state = self.namespace[old_name]

        # Bind the new name.
        self.namespace[new_name] = state

        # Unbind the old one.
        del self.namespace[old_name]

        # Trait event notification. 'context_changed' is an
        # Event(NamingEvent), so assigning a plain value such as True would
        # fail trait validation; fire it the same way as the sibling methods.
        self.trait_property_changed('context_changed', None, None)

        return

    def _create_subcontext(self, name):
        """ Creates a sub-context of this context. """

        sub = self._context_factory(name, {})
        self.namespace[name] = sub

        # Trait event notification.
        self.trait_property_changed('context_changed', None, None)

        return sub

    def _destroy_subcontext(self, name):
        """ Destroys a sub-context of this context. """

        del self.namespace[name]

        # Trait event notification.
        self.trait_property_changed('context_changed', None, None)

        return

    def _list_bindings(self):
        """ Lists the bindings in this context. """

        # Snapshot the keys before resolving each object.
        return [
            Binding(name=name, obj=self._lookup(name), context=self)
            for name in list(self.namespace)
        ]

    def _list_names(self):
        """ Lists the names bound in this context. """

        return list(self.namespace.keys())

    ###########################################################################
    # Private interface.
    ###########################################################################

    def _context_factory(self, name, namespace):
        """ Create a sub-context.

        The sub-context shares this context's environment (and therefore its
        type manager), matching how the base Context creates sub-contexts.

        """

        return self.__class__(namespace=namespace, environment=self.environment)
class Context(HasTraits):
    """ The base class for all naming contexts. """

    # Keys for environment properties.
    INITIAL_CONTEXT_FACTORY = INITIAL_CONTEXT_FACTORY
    OBJECT_FACTORIES = OBJECT_FACTORIES
    STATE_FACTORIES = STATE_FACTORIES

    # Non-JNDI.
    TYPE_MANAGER = TYPE_MANAGER

    #### 'Context' interface ##################################################

    # The naming environment in effect for this context.
    environment = Dict(ENVIRONMENT)

    # The name of the context within its own namespace.
    namespace_name = Property(Str)

    # The type manager in the context's environment (used to create context
    # adapters etc.).
    #
    # fixme: This is an experimental 'convenience' trait, since it is common
    # to get hold of the context's type manager to see if some object has a
    # context adapter.
    type_manager = Property(Instance(TypeManager))

    #### Events ####

    # Fired when an object has been added to the context (either via 'bind' or
    # 'create_subcontext').
    object_added = Event(NamingEvent)

    # Fired when an object has been changed (via 'rebind').
    object_changed = Event(NamingEvent)

    # Fired when an object has been removed from the context (either via
    # 'unbind' or 'destroy_subcontext').
    object_removed = Event(NamingEvent)

    # Fired when an object in the context has been renamed (via 'rename').
    object_renamed = Event(NamingEvent)

    # Fired when the contents of the context have changed dramatically.
    context_changed = Event(NamingEvent)

    #### Protected 'Context' interface #######################################

    # The bindings in the context.
    _bindings = Dict(Str, Any)

    ###########################################################################
    # 'Context' interface.
    ###########################################################################

    #### Properties ###########################################################

    def _get_namespace_name(self):
        """
        Return the name of the context within its own namespace.

        That is the full-path, through the namespace this context participates
        in, to get to this context.  For example, if the root context of the
        namespace was called 'Foo', and there was a subcontext of that called
        'Bar', and we were within that and called 'Baz', then this should
        return 'Foo/Bar/Baz'.

        """

        # FIXME: We'd like to raise an exception and force implementors to
        # decide what to do.  However, it appears to be pretty common that
        # most Context implementations do not override this method -- possibly
        # because the comments aren't clear on what this is supposed to be?
        #
        # Anyway, if we raise an exception then it is impossible to use any
        # evaluations when building a Traits UI for a Context.  That is, the
        # Traits UI can't include items that have a 'visible_when' or
        # 'enabled_when' evaluation.  This is because the Traits evaluation
        # code calls the 'get()' method on the Context which attempts to
        # retrieve the current namespace_name value.
        #raise OperationNotSupportedError()
        return ''

    def _get_type_manager(self):
        """ Returns the type manager in the context's environment.

        This will return None if no type manager was used to create the initial
        context.

        """

        return self.environment.get(self.TYPE_MANAGER)

    #### Methods ##############################################################

    def bind(self, name, obj, make_contexts=False):
        """ Binds a name to an object.

        If 'make_contexts' is True then any missing intermediate contexts are
        created automatically.

        """

        if len(name) == 0:
            raise InvalidNameError('empty name')

        # Parse the name.
        components = self._parse_name(name)

        # If there is exactly one component in the name then the operation
        # takes place in this context.
        if len(components) == 1:
            atom = components[0]

            # Is the name already bound?
            if self._is_bound(atom):
                raise NameAlreadyBoundError(name)

            # Do the actual bind.
            self._bind(atom, obj)

            # Trait event notification.
            self.object_added = NamingEvent(
                new_binding=Binding(name=name, obj=obj, context=self)
            )

        # Otherwise, attempt to continue resolution into the next context.
        else:
            if not self._is_bound(components[0]):
                if make_contexts:
                    self._create_subcontext(components[0])

                else:
                    raise NameNotFoundError(components[0])

            next_context = self._get_next_context(components[0])
            next_context.bind('/'.join(components[1:]), obj, make_contexts)

        return

    def rebind(self, name, obj, make_contexts=False):
        """ Binds an object to a name that may already be bound.

        If 'make_contexts' is True then any missing intermediate contexts are
        created automatically.

        The object may be a different object but may also be the same object
        that is already bound to the specified name. The name may or may not be
        already used. Think of this as a safer version of 'bind' since this
        one will never raise an exception regarding a name being used.

        """

        if len(name) == 0:
            raise InvalidNameError('empty name')

        # Parse the name.
        components = self._parse_name(name)

        # If there is exactly one component in the name then the operation
        # takes place in this context.
        if len(components) == 1:
            # Do the actual rebind.
            self._rebind(components[0], obj)

            # Trait event notification.
            self.object_changed = NamingEvent(
                new_binding=Binding(name=name, obj=obj, context=self)
            )

        # Otherwise, attempt to continue resolution into the next context.
        else:
            if not self._is_bound(components[0]):
                if make_contexts:
                    self._create_subcontext(components[0])

                else:
                    raise NameNotFoundError(components[0])

            next_context = self._get_next_context(components[0])
            next_context.rebind('/'.join(components[1:]), obj, make_contexts)

        return

    def unbind(self, name):
        """ Unbinds a name. """

        if len(name) == 0:
            raise InvalidNameError('empty name')

        # Parse the name.
        components = self._parse_name(name)

        # If there is exactly one component in the name then the operation
        # takes place in this context.
        if len(components) == 1:
            atom = components[0]

            if not self._is_bound(atom):
                raise NameNotFoundError(name)

            # Lookup the object that we are unbinding to use in the event
            # notification.
            obj = self._lookup(atom)

            # Do the actual unbind.
            self._unbind(atom)

            # Trait event notification.
            self.object_removed = NamingEvent(
                old_binding=Binding(name=name, obj=obj, context=self)
            )

        # Otherwise, attempt to continue resolution into the next context.
        else:
            if not self._is_bound(components[0]):
                raise NameNotFoundError(components[0])

            next_context = self._get_next_context(components[0])
            next_context.unbind('/'.join(components[1:]))

        return

    def rename(self, old_name, new_name):
        """ Binds a new name to an object. """

        if len(old_name) == 0 or len(new_name) == 0:
            raise InvalidNameError('empty name')

        # Parse the names.
        old_components = self._parse_name(old_name)
        new_components = self._parse_name(new_name)

        # If there is exactly one component in BOTH names then the operation
        # takes place ENTIRELY in this context.
        if len(old_components) == 1 and len(new_components) == 1:
            # Is the old name actually bound?
            if not self._is_bound(old_name):
                raise NameNotFoundError(old_name)

            # Is the new name already bound?
            if self._is_bound(new_name):
                raise NameAlreadyBoundError(new_name)

            # Do the actual rename.
            self._rename(old_name, new_name)

            # Lookup the object that we are renaming to use in the event
            # notification.
            obj = self._lookup(new_name)

            # Trait event notification.
            self.object_renamed = NamingEvent(
                old_binding=Binding(name=old_name, obj=obj, context=self),
                new_binding=Binding(name=new_name, obj=obj, context=self)
            )

        else:
            # fixme: This really needs to be transactional in case the bind
            # succeeds but the unbind fails.  To be safe should we just not
            # support cross-context renaming for now?!?!
            #
            # Lookup the object.
            obj = self.lookup(old_name)

            # Bind the new name.
            self.bind(new_name, obj)

            # Unbind the old one.
            self.unbind(old_name)

        return

    def lookup(self, name):
        """ Resolves a name relative to this context. """

        # If the name is empty we return the context itself.
        if len(name) == 0:
            # fixme: The JNDI spec. says that this should return a COPY of
            # the context.
            return self

        # Parse the name.
        components = self._parse_name(name)

        # If there is exactly one component in the name then the operation
        # takes place in this context.
        if len(components) == 1:
            atom = components[0]

            if not self._is_bound(atom):
                raise NameNotFoundError(name)

            # Do the actual lookup.
            obj = self._lookup(atom)

        # Otherwise, attempt to continue resolution into the next context.
        else:
            if not self._is_bound(components[0]):
                raise NameNotFoundError(components[0])

            next_context = self._get_next_context(components[0])
            obj = next_context.lookup('/'.join(components[1:]))

        return obj

    # fixme: Non-JNDI
    def lookup_binding(self, name):
        """ Looks up the binding for a name relative to this context. """

        if len(name) == 0:
            raise InvalidNameError('empty name')

        # Parse the name.
        components = self._parse_name(name)

        # If there is exactly one component in the name then the operation
        # takes place in this context.
        if len(components) == 1:
            atom = components[0]

            if not self._is_bound(atom):
                raise NameNotFoundError(name)

            # Do the actual lookup.
            binding = self._lookup_binding(atom)

        # Otherwise, attempt to continue resolution into the next context.
        else:
            if not self._is_bound(components[0]):
                raise NameNotFoundError(components[0])

            next_context = self._get_next_context(components[0])
            binding = next_context.lookup_binding('/'.join(components[1:]))

        return binding

    # fixme: Non-JNDI
    def lookup_context(self, name):
        """ Resolves a name relative to this context.

        The name MUST resolve to a context. This method is useful to return
        context adapters.

        Raises NameNotFoundError if a component is unbound and
        NotContextError if the final object is not (and cannot be adapted
        to) a context.

        """

        # If the name is empty we return the context itself.
        if len(name) == 0:
            # fixme: The JNDI spec. says that this should return a COPY of
            # the context.
            return self

        # Parse the name.
        components = self._parse_name(name)

        # If there is exactly one component in the name then the operation
        # takes place in this context.
        if len(components) == 1:
            atom = components[0]

            if not self._is_bound(atom):
                raise NameNotFoundError(name)

            # Do the actual lookup (adapting to a context if necessary).
            obj = self._get_next_context(atom)

        # Otherwise, attempt to continue resolution into the next context.
        else:
            if not self._is_bound(components[0]):
                raise NameNotFoundError(components[0])

            next_context = self._get_next_context(components[0])
            # Recurse with lookup_context (not lookup) so that the final
            # component is also resolved/adapted to a context.
            obj = next_context.lookup_context('/'.join(components[1:]))

        return obj

    def create_subcontext(self, name):
        """ Creates a sub-context. """

        if len(name) == 0:
            raise InvalidNameError('empty name')

        # Parse the name.
        components = self._parse_name(name)

        # If there is exactly one component in the name then the operation
        # takes place in this context.
        if len(components) == 1:
            atom = components[0]

            # Is the name already bound?
            if self._is_bound(atom):
                raise NameAlreadyBoundError(name)

            # Do the actual creation of the sub-context.
            sub = self._create_subcontext(atom)

            # Trait event notification.
            self.object_added = NamingEvent(
                new_binding=Binding(name=name, obj=sub, context=self)
            )

        # Otherwise, attempt to continue resolution into the next context.
        else:
            if not self._is_bound(components[0]):
                raise NameNotFoundError(components[0])

            next_context = self._get_next_context(components[0])
            sub = next_context.create_subcontext('/'.join(components[1:]))

        return sub

    def destroy_subcontext(self, name):
        """ Destroys a sub-context. """

        if len(name) == 0:
            raise InvalidNameError('empty name')

        # Parse the name.
        components = self._parse_name(name)

        # If there is exactly one component in the name then the operation
        # takes place in this context.
        if len(components) == 1:
            atom = components[0]

            if not self._is_bound(atom):
                raise NameNotFoundError(name)

            obj = self._lookup(atom)
            if not self._is_context(atom):
                raise NotContextError(name)

            # Do the actual destruction of the sub-context.
            self._destroy_subcontext(atom)

            # Trait event notification.
            self.object_removed = NamingEvent(
                old_binding=Binding(name=name, obj=obj, context=self)
            )

        # Otherwise, attempt to continue resolution into the next context.
        else:
            if not self._is_bound(components[0]):
                raise NameNotFoundError(components[0])

            next_context = self._get_next_context(components[0])
            next_context.destroy_subcontext('/'.join(components[1:]))

        return

    # fixme: Non-JNDI
    def get_unique_name(self, prefix):
        """ Returns a name that is unique within the context.

        The name returned will start with the specified prefix.

        """

        return make_unique_name(prefix, existing=self.list_names(''),
                                format='%s (%d)')

    def list_names(self, name=''):
        """ Lists the names bound in a context. """

        # If the name is empty then the operation takes place in this context.
        if len(name) == 0:
            names = self._list_names()

        # Otherwise, attempt to continue resolution into the next context.
        else:
            # Parse the name.
            components = self._parse_name(name)

            if not self._is_bound(components[0]):
                raise NameNotFoundError(components[0])

            next_context = self._get_next_context(components[0])
            names = next_context.list_names('/'.join(components[1:]))

        return names

    def list_bindings(self, name=''):
        """ Lists the bindings in a context. """

        # If the name is empty then the operation takes place in this context.
        if len(name) == 0:
            bindings = self._list_bindings()

        # Otherwise, attempt to continue resolution into the next context.
        else:
            # Parse the name.
            components = self._parse_name(name)

            if not self._is_bound(components[0]):
                raise NameNotFoundError(components[0])

            next_context = self._get_next_context(components[0])
            bindings = next_context.list_bindings('/'.join(components[1:]))

        return bindings

    # fixme: Non-JNDI
    def is_context(self, name):
        """ Returns True if the name is bound to a context. """

        # If the name is empty then it refers to this context.
        if len(name) == 0:
            is_context = True

        else:
            # Parse the name.
            components = self._parse_name(name)

            # If there is exactly one component in the name then the operation
            # takes place in this context.
            if len(components) == 1:
                atom = components[0]

                if not self._is_bound(atom):
                    raise NameNotFoundError(name)

                # Do the actual check.
                is_context = self._is_context(atom)

            # Otherwise, attempt to continue resolution into the next context.
            else:
                if not self._is_bound(components[0]):
                    raise NameNotFoundError(components[0])

                next_context = self._get_next_context(components[0])
                is_context = next_context.is_context('/'.join(components[1:]))

        return is_context

    # fixme: Non-JNDI
    def search(self, obj):
        """ Returns a list of namespace names that are bound to obj. """

        # don't look for None
        if obj is None:
            return []

        # Obj is bound to these names relative to this context
        names = []

        # path contain the name components down to the current context
        path = []

        self._search(obj, names, path, {})

        return names

    ###########################################################################
    # Protected 'Context' interface.
    ###########################################################################

    def _parse_name(self, name):
        """ Parse a name into a list of components.

        e.g. 'foo/bar/baz' -> ['foo', 'bar', 'baz']

        """

        return name.split('/')

    def _is_bound(self, name):
        """ Is a name bound in this context? """

        return name in self._bindings

    def _lookup(self, name):
        """ Looks up a name in this context. """

        obj = self._bindings[name]

        return naming_manager.get_object_instance(obj, name, self)

    def _lookup_binding(self, name):
        """ Looks up the binding for a name in this context. """

        return Binding(name=name, obj=self._lookup(name), context=self)

    def _bind(self, name, obj):
        """ Binds a name to an object in this context. """

        state = naming_manager.get_state_to_bind(obj, name, self)
        self._bindings[name] = state

        return

    def _rebind(self, name, obj):
        """ Rebinds a name to an object in this context. """

        self._bind(name, obj)

        return

    def _unbind(self, name):
        """ Unbinds a name from this context. """

        del self._bindings[name]

        return

    def _rename(self, old_name, new_name):
        """ Renames an object in this context. """

        # Bind the new name.
        self._bindings[new_name] = self._bindings[old_name]

        # Unbind the old one.
        del self._bindings[old_name]

        return

    def _create_subcontext(self, name):
        """ Creates a sub-context of this context. """

        sub = self.__class__(environment=self.environment)
        self._bindings[name] = sub

        return sub

    def _destroy_subcontext(self, name):
        """ Destroys a sub-context of this context. """

        del self._bindings[name]

        return

    def _list_bindings(self):
        """ Lists the bindings in this context. """

        bindings = []
        for name in self._list_names():
            bindings.append(
                Binding(name=name, obj=self._lookup(name), context=self)
            )

        return bindings

    def _list_names(self):
        """ Lists the names bound in this context. """

        return self._bindings.keys()

    def _is_context(self, name):
        """ Returns True if a name is bound to a context.

        '_get_next_context' raises NotContextError (instead of returning
        None) when the bound object is neither a Context nor adaptable to
        one, so that case is reported as False here rather than escaping as
        an exception.

        """

        try:
            next_context = self._get_next_context(name)

        except NotContextError:
            return False

        return next_context is not None

    def _get_next_context(self, name):
        """ Returns the next context.

        Raises NotContextError if the bound object is not a Context and no
        context adapter is available for it.

        """

        obj = self._lookup(name)

        # If the object is a context then everything is just dandy.
        if isinstance(obj, Context):
            next_context = obj

        # Otherwise, instead of just giving up, see if the context has a type
        # manager that knows how to adapt the object to make it quack like a
        # context.
        else:
            next_context = self._get_context_adapter(obj)

            # If no adapter was found then we cannot continue name resolution.
            if next_context is None:
                raise NotContextError(name)

        return next_context

    def _search(self, obj, names, path, searched):
        """ Append to names any name bound to obj.

        Join path and name with '/' to form a complete name from the
        top context.

        """

        # Check the bindings recursively.
        for binding in self.list_bindings():
            if binding.obj is obj:
                path.append(binding.name)
                names.append('/'.join(path))
                path.pop()

            if isinstance(binding.obj, Context) \
                    and not binding.obj in searched:
                path.append(binding.name)
                searched[binding.obj] = True
                binding.obj._search(obj, names, path, searched)
                path.pop()

        return

    ###########################################################################
    # Private interface.
    ###########################################################################

    def _get_context_adapter(self, obj):
        """ Returns a context adapter for an object.

        Returns None if no such adapter is available.

        """

        if self.type_manager is not None:
            adapter = self.type_manager.object_as(
                obj, Context, environment=self.environment, context=self
            )

        else:
            adapter = None

        return adapter
class UnstructuredGridReader(FileDataSource):
    """File data source that chooses a TVTK unstructured-grid reader from the
    file extension ('inp', 'neu' or 'exii', see the reader dict default)."""

    # The version of this class.  Used for persistence.
    __version__ = 0

    # The UnstructuredGridAlgorithm data file reader.
    reader = Instance(tvtk.Object, allow_none=False, record=True)

    # Information about what this object can produce.
    output_info = PipelineInfo(datasets=['unstructured_grid'])

    ######################################################################
    # Private Traits

    # Lower-case file extension -> reader instance to use for it.
    _reader_dict = Dict(Str, Instance(tvtk.Object))

    # Our view.
    view = View(Group(Include('time_step_group'),
                      Item(name='base_file_name'),
                      Item(name='reader',
                           style='custom',
                           resizable=True),
                      show_labels=False),
                resizable=True)

    ######################################################################
    # `object` interface
    ######################################################################

    def __set_pure_state__(self, state):
        # The reader has its own file_name which needs to be fixed.
        state.reader.file_name = state.file_path.abs_pth
        # Now call the parent class to setup everything.
        super(UnstructuredGridReader, self).__set_pure_state__(state)

    ######################################################################
    # `FileDataSource` interface
    ######################################################################

    def update(self):
        # Note: the reader is updated before the empty-path check.
        self.reader.update()
        if len(self.file_path.get()) == 0:
            return
        self.render()

    ######################################################################
    # Non-public interface
    ######################################################################

    def _file_path_changed(self, fpath):
        """Switch to the reader matching the new file's extension.

        Reports an error (and keeps the current reader) when the extension
        has no entry in _reader_dict.
        """
        value = fpath.get()
        if len(value) == 0:
            return

        # Extract the file extension
        splitname = value.strip().split('.')
        extension = splitname[-1].lower()

        # Select UnstructuredGridReader based on file type.
        # ('in' instead of the deprecated dict.has_key, removed in Python 3.)
        old_reader = self.reader
        if extension in self._reader_dict:
            self.reader = self._reader_dict[extension]
        else:
            error('Invalid file extension for file: %s' % value)
            return

        self.reader.file_name = value.strip()
        self.reader.update()
        self.reader.update_information()

        # Move the render listener from the previous reader to the new one.
        if old_reader is not None:
            old_reader.on_trait_change(self.render, remove=True)
        self.reader.on_trait_change(self.render)

        old_outputs = self.outputs
        self.outputs = [self.reader.output]

        # Assigning an equal outputs list may not signal a change, so flag
        # data_changed explicitly in that case.
        if self.outputs == old_outputs:
            self.data_changed = True

        # Change our name on the tree view
        self.name = self._get_name()

    def _get_name(self):
        """ Returns the name to display on the tree view.  Note that
        this is not a property getter.
        """
        fname = basename(self.file_path.get())
        ret = "%s" % fname
        if len(self.file_list) > 1:
            ret += " (timeseries)"
        if '[Hidden]' in self.name:
            ret += ' [Hidden]'

        return ret

    def __reader_dict_default(self):
        """Default value for reader dict."""
        rd = {
            'inp': tvtk.AVSucdReader(),
            'neu': tvtk.GAMBITReader(),
            'exii': tvtk.ExodusReader(),
        }

        return rd
class MyOtherClass(HasTraits):
    """ Minimal HasTraits subclass with one Dict trait, `d`, pre-filled
    with four fruit names keyed by their first letter. """

    # Default contents of the `d` trait.
    d = Dict(dict(a="apple", b="banana", c="cherry", d="durian"))
class InstanceFactoryChoice ( InstanceChoiceItem ):
    """ Choice item that produces a brand-new object by calling **klass**
        with **args** and **kw_args** whenever an object is requested.
    """

    #---------------------------------------------------------------------------
    #  Trait definitions:
    #---------------------------------------------------------------------------

    # Whether a compatible instance may be dragged and dropped onto this
    # item instead of one being created:
    droppable = false

    # Whether the user is allowed to select this item:
    selectable = true

    # Class (or any other callable) used to build objects for this item:
    klass = Callable

    # Positional arguments handed to **klass** when building an object:
    args = Tuple

    # Keyword arguments handed to **klass** when building an object:
    kw_args = Dict( Str, Any )

    # This item creates new instances (overrides the inherited default):
    is_factory = True

    #---------------------------------------------------------------------------
    #  Public API:
    #---------------------------------------------------------------------------

    def get_name ( self, object = None ):
        """ Returns the name of the item.

            Preference order: an explicitly set ``name``; then a string
            ``name`` attribute found on *object*; finally the
            ``user_name_for`` form of the target class name.
        """
        explicit = self.name
        if explicit != '':
            return explicit

        object_name = getattr( object, 'name', None )
        if isinstance( object_name, basestring ):
            return object_name

        return user_name_for( self._target_class().__name__ )

    def get_object ( self ):
        """ Builds and returns a new object by calling **klass**.
        """
        factory = self.klass
        return factory( *self.args, **self.kw_args )

    def is_droppable ( self ):
        """ Tells whether the item supports drag and drop.
        """
        return self.droppable

    def is_compatible ( self, object ):
        """ Tells whether *object* is an instance of the class this item
            produces.
        """
        return isinstance( object, self._target_class() )

    def is_selectable ( self ):
        """ Tells whether the user may select the item.
        """
        return self.selectable

    #---------------------------------------------------------------------------
    #  Private helpers:
    #---------------------------------------------------------------------------

    def _target_class ( self ):
        """ Returns **klass** when it is itself a class; otherwise (e.g. a
            factory function) the class of a freshly built object.
        """
        if issubclass( type( self.klass ), type ):
            return self.klass
        return self.get_object().__class__
class Mesh(Base):
    """Container holding string-keyed `Group` and `GroupGroup` mappings."""

    # Mapping from a string key to a `Group` instance.
    group = Dict(value_trait=Group, key_trait=String)

    # Mapping from a string key to a `GroupGroup` instance.
    groupGroup = Dict(value_trait=GroupGroup, key_trait=String)
class ParametricSurface(Source):
    """Source that generates poly data from one of tvtk's parametric
    functions, selected by name through the `function` trait."""

    # The version of this class.  Used for persistence.
    __version__ = 0

    # Name of the parametric function to use.
    function = Enum('boy', 'conic_spiral', 'cap', 'dini',
                    'ellipsoid', 'enneper', 'figure8klein', 'klein',
                    'mobius', 'random_hills', 'roman', 'spline',
                    'super_ellipsoid', 'super_toroid', 'torus',
                    desc='which parametric function to be used')

    # The currently active parametric function object; must be an
    # instance of tvtk.ParametricFunction.
    parametric_function = Instance(tvtk.ParametricFunction,
                                   allow_none=False, record=True)

    # The parametric function source which generates the data.
    source = Instance(tvtk.ParametricFunctionSource,
                      args=(), kw={'scalar_mode': 'distance'},
                      allow_none=False, record=True)

    # Information about what this object can produce.
    output_info = PipelineInfo(datasets=['poly_data'],
                               attribute_types=['any'],
                               attributes=['any'])

    ########################################
    # Private traits.

    # Maps each `function` name to its cached parametric function object.
    _function_dict = Dict(Str,
                          Instance(tvtk.ParametricFunction,
                                   allow_none=False))

    ######################################################################
    # `object` interface
    ######################################################################
    def __init__(self, **traits):
        # (name, tvtk class) pairs; one object of each class is built up
        # front, in this order, and cached by name.
        factories = (
            ('boy', tvtk.ParametricBoy),
            ('conic_spiral', tvtk.ParametricConicSpiral),
            ('cap', tvtk.ParametricCrossCap),
            ('dini', tvtk.ParametricDini),
            ('ellipsoid', tvtk.ParametricEllipsoid),
            ('enneper', tvtk.ParametricEnneper),
            ('figure8klein', tvtk.ParametricFigure8Klein),
            ('klein', tvtk.ParametricKlein),
            ('mobius', tvtk.ParametricMobius),
            ('random_hills', tvtk.ParametricRandomHills),
            ('roman', tvtk.ParametricRoman),
            ('spline', tvtk.ParametricSpline),
            ('super_ellipsoid', tvtk.ParametricSuperEllipsoid),
            ('super_toroid', tvtk.ParametricSuperToroid),
            ('torus', tvtk.ParametricTorus),
        )
        # The mapping must exist before the base initialiser runs: a
        # `function` value passed in **traits may fire `_function_changed`,
        # which indexes this dict.
        self._function_dict = dict((key, cls()) for key, cls in factories)

        super(ParametricSurface, self).__init__(**traits)

        # Start from the object matching the current `function` value.
        self.parametric_function = self._function_dict[self.function]

        # Re-render whenever any trait of the source changes.
        self.source.on_trait_change(self.render)

        # Expose the source's output.
        self.outputs = [self.source.output]

    ######################################################################
    # Non-public methods.
    ######################################################################
    def _function_changed(self, value):
        """Called automatically when `function` changes: switch to the
        cached parametric function registered under the new name."""
        self.parametric_function = self._function_dict[value]

    def _parametric_function_changed(self, old, new):
        """Called automatically when `parametric_function` changes: hand
        the new object to the source and move the render listener from
        the old object to the new one."""
        self.source.parametric_function = new
        if old is not None:
            old.on_trait_change(self.render, remove=True)
        new.on_trait_change(self.render)
        self.data_changed = True
class Registry(HasTraits):

    """
    This class is a registry for various engines, and metadata from
    sources, filters and modules
    """

    # The mayavi engines used.
    engines = Dict(Str, Instance('enthought.mayavi.core.engine.Engine'))

    # The metadata for the sources.
    sources = List(Metadata)

    # The metadata for the modules.
    modules = List(Metadata)

    # The metadata for the filters.
    filters = List(Metadata)

    ######################################################################
    # `Registry` interface.
    ######################################################################
    def register_engine(self, engine, name=''):
        """Registers a mayavi engine with an optional name.

        An explicit `name` may equal that of an engine already registered,
        in which case the old entry is replaced.  When no name is given,
        one is generated from the engine's class name and a counter.  The
        first such name not already in use is chosen, so an auto-named
        engine never silently replaces another one.
        """
        engines = self.engines
        if len(name) == 0:
            cls_name = engine.__class__.__name__
            count = len(engines) + 1
            name = '%s%d' % (cls_name, count)
            # After an unregister, len()+1 can collide with a name that is
            # still in use (Engine1, Engine2 registered; Engine1 removed).
            while name in engines:
                count += 1
                name = '%s%d' % (cls_name, count)
        logger.debug('Engine [%s] named %s registered', engine, name)
        engines[name] = engine

    def unregister_engine(self, engine_or_name):
        """Unregisters a mayavi engine specified either as a name or an
        engine instance.

        Raises KeyError if no engine of that name, or no engine equal to
        the given instance, is registered.
        """
        engines = self.engines
        if isinstance(engine_or_name, str):
            name = engine_or_name
        else:
            for key, engine in engines.items():
                if engine_or_name == engine:
                    name = key
                    break
            else:
                # Without this, `name` would be unbound below and an
                # obscure UnboundLocalError would escape.
                raise KeyError('Engine %r is not registered'
                               % (engine_or_name,))
        del engines[name]
        logger.debug('Engine named %s unregistered', name)

    def get_file_reader(self, filename):
        """Given a filename, find a suitable source metadata that will
        read the file.

        Candidates are the source metadata whose `extensions` contain the
        file's extension (without the leading dot).  If more than one
        candidate matches, each candidate with a non-empty `can_read_test`
        is asked whether it can read the file.  The first one answering
        yes is returned immediately; those answering no are discarded.
        Of the remaining candidates the last one is returned.  Returns
        None when nothing matches.
        """
        ext = splitext(filename)[1]
        result = []
        if len(ext) > 0:
            ext = ext[1:]
            result = [src for src in self.sources
                      if ext in src.extensions]

        if len(result) > 1:
            # Iterate over a copy because rejected candidates are removed.
            for res in result[:]:
                if len(res.can_read_test) > 0:
                    can_read = import_symbol(res.can_read_test)(filename)
                    if can_read:
                        return res
                    else:
                        result.remove(res)

        if len(result) == 0:
            return None
        return result[-1]

    def find_scene_engine(self, scene):
        """ Find the engine corresponding to a given tvtk scene.

        `scene` may be one of an engine's scenes, that scene's `scene`
        attribute, or (for scene model objects) that object's
        `scene_editor`.  Raises TypeError if no registered engine owns it.
        """
        for engine in self.engines.values():
            for s in engine.scenes:
                if scene is s:
                    return engine
                sc = s.scene
                if scene is sc:
                    return engine
                elif hasattr(sc, 'scene_editor') and \
                         scene is sc.scene_editor:
                    # This check is needed for scene model objects.
                    return engine
        raise TypeError("Scene not attached to a mayavi engine.")
class SceneModel(TVTKScene):
    """ Window-less TVTKScene model.

    It records actors and widgets and fires `do_render`.  Saving, sizing
    and zooming are delegated to the attached `scene_editor`, and most of
    those operations require one (see `_check_scene_editor`).
    """

    ########################################
    # TVTKScene traits.

    light_manager = Property

    picker = Property

    ########################################
    # SceneModel traits.

    # A convenient dictionary based interface to add/remove actors and widgets.
    # This is similar to the interface provided for the ActorEditor.
    actor_map = Dict()

    # This is used primarily to implement the add_actor/remove_actor methods.
    actor_list = List()

    # The actual scene being edited.
    scene_editor = Instance(TVTKScene)

    # Fired by `render` to request a redraw.
    do_render = Event()

    # Fired when this is activated.
    activated = Event()

    # Fired when this widget is closed.
    closing = Event()

    # This exists just to mirror the TVTKWindow api.
    scene = Property

    ###################################
    # View related traits.

    # Render_window's view.
    _stereo_view = Group(Item(name='stereo_render'),
                         Item(name='stereo_type'),
                         show_border=True,
                         label='Stereo rendering',
                         )

    # The default view of this object.
    default_view = View(Group(
                            Group(Item(name='background'),
                                  Item(name='foreground'),
                                  Item(name='parallel_projection'),
                                  Item(name='disable_render'),
                                  Item(name='off_screen_rendering'),
                                  Item(name='jpeg_quality'),
                                  Item(name='jpeg_progressive'),
                                  Item(name='magnification'),
                                  Item(name='anti_aliasing_frames'),
                                  ),
                            Group(Item(name='render_window',
                                       style='custom',
                                       visible_when='object.stereo',
                                       editor=InstanceEditor(view=View(_stereo_view)),
                                       show_label=False),
                                  ),
                            label='Scene'),
                        Group(Item(name='light_manager',
                                   style='custom',
                                   editor=InstanceEditor(),
                                   show_label=False),
                              label='Lights'))

    ###################################
    # Private traits.

    # Used by the editor to determine if the widget was enabled or not.
    enabled_info = Dict()

    def __init__(self, parent=None, **traits):
        """ Initializes the object. """
        # Base class constructor.  We call TVTKScene's super here on purpose.
        # Calling TVTKScene's init will create a new window which we do not
        # want.
        super(TVTKScene, self).__init__(**traits)
        self.control = None

    ######################################################################
    # TVTKScene API.
    ######################################################################
    def render(self):
        """ Force the scene to be rendered. Nothing is done if the
        `disable_render` trait is set to True."""
        if not self.disable_render:
            self.do_render = True

    def add_actors(self, actors):
        """ Adds a single actor or a tuple or list of actors to the
        renderer."""
        if hasattr(actors, '__iter__'):
            self.actor_list.extend(actors)
        else:
            self.actor_list.append(actors)

    def remove_actors(self, actors):
        """ Removes a single actor or a tuple or list of actors from the
        renderer."""
        my_actors = self.actor_list
        if hasattr(actors, '__iter__'):
            for actor in actors:
                my_actors.remove(actor)
        else:
            my_actors.remove(actors)

    # Convenience methods.
    add_actor = add_actors
    remove_actor = remove_actors

    def add_widgets(self, widgets, enabled=True):
        """Adds a single widget or an iterable of widgets to the renderer.

        `enabled` is recorded for every widget in `enabled_info`.
        """
        if hasattr(widgets, '__iter__'):
            # Materialise once: the widgets are walked twice below, which
            # would silently drop everything from a one-shot iterator.
            widgets = list(widgets)
        else:
            widgets = [widgets]
        for widget in widgets:
            self.enabled_info[widget] = enabled
        self.add_actors(widgets)

    def remove_widgets(self, widgets):
        """Removes a single widget or an iterable of widgets from the
        renderer, forgetting their `enabled_info` entries."""
        if hasattr(widgets, '__iter__'):
            # Materialise once: the widgets are walked twice below.
            widgets = list(widgets)
        else:
            widgets = [widgets]
        self.remove_actors(widgets)
        for widget in widgets:
            del self.enabled_info[widget]

    def reset_zoom(self):
        """Reset the camera so everything in the scene fits."""
        if self.scene_editor is not None:
            self.scene_editor.reset_zoom()

    def save(self, file_name, size=None, **kw_args):
        """Saves rendered scene to one of several image formats
        depending on the specified extension of the filename.

        If an additional size (2-tuple) argument is passed the window
        is resized to the specified size in order to produce a
        suitably sized output image.  Please note that when the window
        is resized, the window may be obscured by other widgets and
        the camera zoom is not reset which is likely to produce an
        image that does not reflect what is seen on screen.

        Any extra keyword arguments are passed along to the respective
        image format's save method.
        """
        self._check_scene_editor()
        self.scene_editor.save(file_name, size, **kw_args)

    def save_ps(self, file_name):
        """Saves the rendered scene to a rasterized PostScript image.
        For vector graphics use the save_gl2ps method."""
        self._check_scene_editor()
        self.scene_editor.save_ps(file_name)

    def save_bmp(self, file_name):
        """Save to a BMP image file."""
        self._check_scene_editor()
        self.scene_editor.save_bmp(file_name)

    def save_tiff(self, file_name):
        """Save to a TIFF image file."""
        self._check_scene_editor()
        self.scene_editor.save_tiff(file_name)

    def save_png(self, file_name):
        """Save to a PNG image file."""
        self._check_scene_editor()
        self.scene_editor.save_png(file_name)

    def save_jpg(self, file_name, quality=None, progressive=None):
        """Arguments: file_name if passed will be used, quality is the
        quality of the JPEG(10-100) are valid, the progressive
        arguments toggles progressive jpegs."""
        self._check_scene_editor()
        self.scene_editor.save_jpg(file_name, quality, progressive)

    def save_iv(self, file_name):
        """Save to an OpenInventor file."""
        self._check_scene_editor()
        self.scene_editor.save_iv(file_name)

    def save_vrml(self, file_name):
        """Save to a VRML file."""
        self._check_scene_editor()
        self.scene_editor.save_vrml(file_name)

    def save_oogl(self, file_name):
        """Saves the scene to a Geomview OOGL file. Requires VTK 4 to
        work."""
        self._check_scene_editor()
        self.scene_editor.save_oogl(file_name)

    def save_rib(self, file_name, bg=0, resolution=None, resfactor=1.0):
        """Save scene to a RenderMan RIB file.

        Keyword Arguments:

        file_name -- File name to save to.

        bg -- Optional background option.  If 0 (the default) then no
        background is saved; any other non-None value saves a
        background.  Passing None results in a pop-up window asking
        for yes/no.

        resolution -- Specify the resolution of the generated image in
        the form of a tuple (nx, ny).

        resfactor -- The resolution factor which scales the resolution.
        """
        self._check_scene_editor()
        self.scene_editor.save_rib(file_name, bg, resolution, resfactor)

    def save_wavefront(self, file_name):
        """Save scene to a Wavefront OBJ file.  Two files are
        generated.  One with a .obj extension and another with a .mtl
        extension which contains the material properties.

        Keyword Arguments:

        file_name -- File name to save to
        """
        self._check_scene_editor()
        self.scene_editor.save_wavefront(file_name)

    def save_gl2ps(self, file_name, exp=None):
        """Save scene to a vector PostScript/EPS/PDF/TeX file using
        GL2PS.  If you choose to use a TeX file then note that only
        the text output is saved to the file.  You will need to save
        the graphics separately.

        Keyword Arguments:

        file_name -- File name to save to.

        exp -- Optionally configured vtkGL2PSExporter object.
        Defaults to None and this will use the default settings with
        the output file type chosen based on the extension of the file
        name.
        """
        self._check_scene_editor()
        self.scene_editor.save_gl2ps(file_name, exp)

    def get_size(self):
        """Return size of the render window."""
        self._check_scene_editor()
        return self.scene_editor.get_size()

    def set_size(self, size):
        """Set the size of the window."""
        self._check_scene_editor()
        self.scene_editor.set_size(size)

    def _update_view(self, x, y, z, vx, vy, vz):
        """Used internally to set the view."""
        if self.scene_editor is not None:
            self.scene_editor._update_view(x, y, z, vx, vy, vz)

    def _check_scene_editor(self):
        """Raise SceneModelError when no scene editor is attached."""
        if self.scene_editor is None:
            msg = """
            This method requires that there be an active scene editor.
            To do this, you will typically need to invoke::
              object.edit_traits()
            where object is the object that contains the SceneModel.
            """
            raise SceneModelError(msg)

    def _scene_editor_changed(self, old, new):
        # Mirror (or clear) the editor's VTK renderer, render window and
        # interactor so the TVTKScene API sees the live objects.
        if new is None:
            self._renderer = None
            self._renwin = None
            self._interactor = None
        else:
            self._renderer = new._renderer
            self._renwin = new._renwin
            self._interactor = new._interactor

    def _get_picker(self):
        """Getter for the picker."""
        se = self.scene_editor
        if se is not None and hasattr(se, 'picker'):
            return se.picker
        return None

    def _get_light_manager(self):
        """Getter for the light manager."""
        se = self.scene_editor
        if se is not None:
            return se.light_manager
        return None

    ######################################################################
    # SceneModel API.
    ######################################################################
    def _get_scene(self):
        """Getter for the scene property."""
        return self
class GlyphSource(Component):
    """Component supplying the glyph geometry.

    The source is chosen from `glyph_dict` and placed relative to each
    data point according to `glyph_position`.
    """

    # The version of this class.  Used for persistence.
    __version__ = 1

    # Glyph position.  This can be one of ['head', 'tail', 'center'],
    # and indicates the position of the glyph with respect to the
    # input point data.  Please note that this will work correctly
    # only if you do not mess with the source glyph's basic size.  For
    # example if you use a ConeSource and set its height != 1, then the
    # 'head' and 'tail' options will not work correctly.
    glyph_position = Trait('center', TraitPrefixList(['head', 'tail',
                                                      'center']),
                           desc='position of glyph w.r.t. data point')

    # The Source to use for the glyph.  This is chosen from
    # `self.glyph_list` or `self.glyph_dict`.
    glyph_source = Instance(tvtk.Object, allow_none=False, record=True)

    # A dict of glyphs to use.
    glyph_dict = Dict(desc='the glyph sources to select from',
                      record=False)

    # A list of predefined glyph sources that can be used.
    glyph_list = Property(List(tvtk.Object), record=False)

    ########################################
    # Private traits.

    # The transformation to use to place glyph appropriately.
    _trfm = Instance(tvtk.TransformFilter, args=())

    # Re-entrancy guard: True while the handlers below are updating the
    # glyph; suppresses nested handler runs and `render` calls.
    _updating = Bool(False)

    ########################################
    # View related traits.

    view = View(Group(
                    Group(Item(name='glyph_position')),
                    Group(Item(name='glyph_source',
                               style='custom',
                               resizable=True,
                               editor=InstanceEditor(name='glyph_list'),
                               ),
                          label='Glyph Source',
                          show_labels=False)
                    ),
                resizable=True)

    ######################################################################
    # `Base` interface
    ######################################################################
    def __get_pure_state__(self):
        d = super(GlyphSource, self).__get_pure_state__()
        for attr in ('_updating', 'glyph_list'):
            d.pop(attr, None)
        return d

    def __set_pure_state__(self, state):
        if 'glyph_dict' in state:
            # Set their state.
            set_state(self, state, first=['glyph_dict'], ignore=['*'])
            ignore = ['glyph_dict']
        else:
            # Set the dict state using the persisted list.
            gd = self.glyph_dict
            gl = self.glyph_list
            handle_children_state(gl, state.glyph_list)
            for g, gs in zip(gl, state.glyph_list):
                name = camel2enthought(g.__class__.__name__)
                if name not in gd:
                    gd[name] = g
                # Set the glyph source's state.
                set_state(g, gs)
            ignore = ['glyph_list']
        g_name = state.glyph_source.__metadata__['class_name']
        name = camel2enthought(g_name)
        # Set the correct glyph_source.
        self.glyph_source = self.glyph_dict[name]
        set_state(self, state, ignore=ignore)

    ######################################################################
    # `Component` interface
    ######################################################################
    def setup_pipeline(self):
        """Override this method so that it *creates* the tvtk
        pipeline.

        This method is invoked when the object is initialized via
        `__init__`.  Note that at the time this method is called, the
        tvtk data pipeline will *not* yet be setup.  So upstream data
        will not be available.  The idea is that you simply create the
        basic objects and setup those parts of the pipeline not
        dependent on upstream sources and filters.  You should also
        set the `actors` attribute up at this point.
        """
        self._trfm.transform = tvtk.Transform()
        # Setup the glyphs.
        self.glyph_source = self.glyph_dict['glyph_source2d']

    def update_pipeline(self):
        """Override this method so that it *updates* the tvtk pipeline
        when data upstream is known to have changed.

        This method is invoked (automatically) when any of the inputs
        sends a `pipeline_changed` event.
        """
        self._glyph_position_changed(self.glyph_position)
        self.pipeline_changed = True

    def update_data(self):
        """Override this method so that it flushes the vtk pipeline if
        that is necessary.

        This method is invoked (automatically) when any of the inputs
        sends a `data_changed` event.
        """
        self.data_changed = True

    def render(self):
        # Skip rendering while a handler is mid-update.
        if not self._updating:
            super(GlyphSource, self).render()

    ######################################################################
    # Non-public methods.
    ######################################################################
    def _glyph_source_changed(self, value):
        """Register a new glyph source in `glyph_dict`, record the change
        for the script recorder and route its output into our outputs."""
        if self._updating:
            return

        gd = self.glyph_dict
        value_cls = camel2enthought(value.__class__.__name__)
        if value not in gd.values():
            gd[value_cls] = value

        # Now change the glyph's source trait.  The guard is reset in a
        # `finally` so an exception cannot leave it stuck at True, which
        # would permanently silence updates and renders.
        self._updating = True
        try:
            recorder = self.recorder
            if recorder is not None:
                name = recorder.get_script_id(self)
                lhs = '%s.glyph_source'%name
                rhs = '%s.glyph_dict[%r]'%(name, value_cls)
                recorder.record('%s = %s'%(lhs, rhs))

            name = value.__class__.__name__
            if name == 'GlyphSource2D':
                # 2D glyphs are used directly, bypassing the transform.
                self.outputs = [value.output]
            else:
                self._trfm.input = value.output
                self.outputs = [self._trfm.output]
            value.on_trait_change(self.render)
        finally:
            self._updating = False

        # Now update the glyph position since the transformation might
        # be different.
        self._glyph_position_changed(self.glyph_position)

    def _glyph_position_changed(self, value):
        """Offset the current glyph source (via its `center` or the
        transform) so it sits at `value` ('head', 'tail' or 'center')
        relative to each data point."""
        if self._updating:
            return

        self._updating = True
        try:
            tr = self._trfm.transform
            tr.identity()

            g = self.glyph_source
            name = g.__class__.__name__
            # Compute transformation factor
            if name == 'CubeSource':
                tr_factor = g.x_length/2.0
            elif name == 'CylinderSource':
                tr_factor = -g.height/2.0
            elif name == 'ConeSource':
                tr_factor = g.height/2.0
            elif name == 'SphereSource':
                tr_factor = g.radius
            else:
                tr_factor = 1.

            # Translate the glyph.  Sources without a `center` trait
            # (e.g. Axes) are left alone in every branch.
            if value == 'tail':
                if name == 'GlyphSource2D':
                    g.center = 0.5, 0.0, 0.0
                elif name == 'ArrowSource':
                    pass
                elif name == 'CylinderSource':
                    g.center = 0, tr_factor, 0.0
                elif hasattr(g, 'center'):
                    g.center = tr_factor, 0.0, 0.0
            elif value == 'head':
                if name == 'GlyphSource2D':
                    g.center = -0.5, 0.0, 0.0
                elif name == 'ArrowSource':
                    tr.translate(-1, 0, 0)
                elif name == 'CylinderSource':
                    g.center = 0, -tr_factor, 0.0
                elif hasattr(g, 'center'):
                    g.center = -tr_factor, 0.0, 0.0
            else:
                if name == 'ArrowSource':
                    tr.translate(-0.5, 0, 0)
                elif hasattr(g, 'center'):
                    g.center = 0.0, 0.0, 0.0

            if name == 'CylinderSource':
                # The cylinder offsets above are along y; presumably this
                # rotation brings them onto x like the other glyphs.
                tr.rotate_z(90)
        finally:
            self._updating = False
        self.render()

    def _get_glyph_list(self):
        # Return the glyph list as per the original order in earlier
        # implementation, followed by any extra entries in dict order.
        # Default keys missing from a custom dict are skipped instead of
        # raising KeyError.
        order = ['glyph_source2d', 'arrow_source', 'cone_source',
                 'cylinder_source', 'sphere_source', 'cube_source',
                 'axes']
        gd = self.glyph_dict
        keys = [key for key in order if key in gd]
        keys.extend(key for key in gd if key not in order)
        return [gd[key] for key in keys]

    def _glyph_dict_default(self):
        g = {'glyph_source2d': tvtk.GlyphSource2D(glyph_type='arrow', filled=False),
             'arrow_source': tvtk.ArrowSource(),
             'cone_source': tvtk.ConeSource(height=1.0, radius=0.2, resolution=15),
             'cylinder_source': tvtk.CylinderSource(height=1.0, radius=0.15,
                                                    resolution=10),
             'sphere_source': tvtk.SphereSource(),
             'cube_source': tvtk.CubeSource(),
             'axes': tvtk.Axes(symmetric=1)}
        return g
class MousePickDispatcher(HasTraits):
    """ An event dispatcher to send pick events on mouse clicks.

        This object wires VTK observers so that picking callbacks
        can be bound to a mouse click without movement.

        The object deals with adding and removing the VTK-level
        callbacks.
    """

    # The scene events are wired to.
    scene = Instance(Scene)

    # The list of callbacks, with the picker type they should be using,
    # and the mouse button that triggers them.
    callbacks = List(Tuple(
                        Callable,
                        Enum('cell', 'point', 'world'),
                        Enum('Left', 'Middle', 'Right'),
                        ),
                    help="The list of callbacks, with the picker type they "
                         "should be using, and the mouse button that "
                         "triggers them. The callback is passed "
                         "as an argument the tvtk picker."
                    )

    #--------------------------------------------------------------------------
    # Private traits
    #--------------------------------------------------------------------------

    # Countdown set on button press and decremented by movement events;
    # a pick happens on release only while it is still non-zero.
    _mouse_no_mvt = Int

    # The button that has been pressed
    _current_button = Enum('Left', 'Middle', 'Right')

    # The various pickers that are used when the mouse is pressed,
    # keyed by picker type.
    _active_pickers = Dict

    # The VTK callback numbers corresponding to our callbacks
    _picker_callback_nbs = Dict(value_trait=Int)

    # The VTK callback numbers corresponding to mouse movement
    _mouse_mvt_callback_nb = Int

    # The VTK callback numbers corresponding to mouse press
    _mouse_press_callback_nbs = Dict

    # The VTK callback numbers corresponding to mouse release
    _mouse_release_callback_nbs = Dict

    #--------------------------------------------------------------------------
    # Callbacks management
    #--------------------------------------------------------------------------

    @on_trait_change('callbacks_items')
    def dispatch_callbacks_change(self, name, trait_list_event):
        for item in trait_list_event.added:
            self.callback_added(item)
        for item in trait_list_event.removed:
            self.callback_removed(item)

    def callback_added(self, item):
        """ Wire up the different VTK callbacks.
        """
        callback, picker_type, button = item
        picker = getattr(self.scene.scene.picker, '%spicker' % picker_type)
        self._active_pickers[picker_type] = picker

        # Register the pick callback
        if picker_type not in self._picker_callback_nbs:
            self._picker_callback_nbs[picker_type] = \
                            picker.add_observer("EndPickEvent",
                                                self.on_pick)

        # Register the callbacks on the scene interactor
        if VTK_VERSION > 5:
            move_event = "RenderEvent"
        else:
            move_event = 'MouseMoveEvent'
        if not self._mouse_mvt_callback_nb:
            self._mouse_mvt_callback_nb = \
                self.scene.scene.interactor.add_observer(move_event,
                                                self.on_mouse_move)
        if button not in self._mouse_press_callback_nbs:
            self._mouse_press_callback_nbs[button] = \
                self.scene.scene.interactor.add_observer(
                                    '%sButtonPressEvent' % button,
                                    self.on_button_press)
        if VTK_VERSION > 5:
            release_event = "EndInteractionEvent"
        else:
            release_event = '%sButtonReleaseEvent' % button
        if button not in self._mouse_release_callback_nbs:
            self._mouse_release_callback_nbs[button] = \
                self.scene.scene.interactor.add_observer(
                                    release_event,
                                    self.on_button_release)

    def callback_removed(self, item):
        """ Clean up the unnecessary VTK callbacks.

            Observer ids are forgotten once their observer is removed, so
            that `callback_added` registers a fresh observer if a callback
            with the same picker type or button is added again later.
            Each cleanup runs at most once, even when several callbacks
            sharing a type or button are removed in a single event.
        """
        callback, picker_type, button = item

        # If the picker is no longer needed, clean up its observers.
        if picker_type in self._active_pickers and \
                not any(t == picker_type for c, t, b in self.callbacks):
            picker = self._active_pickers[picker_type]
            picker.remove_observer(self._picker_callback_nbs[picker_type])
            del self._active_pickers[picker_type]
            del self._picker_callback_nbs[picker_type]

        # If there are no longer callbacks on the button, clean up
        # the corresponding observers.
        if button in self._mouse_press_callback_nbs and \
                not any(b == button for c, t, b in self.callbacks):
            self.scene.scene.interactor.remove_observer(
                    self._mouse_press_callback_nbs.pop(button))
            self.scene.scene.interactor.remove_observer(
                    self._mouse_release_callback_nbs.pop(button))
        if len(self.callbacks) == 0 and self._mouse_mvt_callback_nb:
            self.scene.scene.interactor.remove_observer(
                            self._mouse_mvt_callback_nb)
            self._mouse_mvt_callback_nb = 0

    def clear_callbacks(self):
        """ Remove all callbacks one at a time, unwiring their observers.
        """
        while self.callbacks:
            self.callbacks.pop()

    #--------------------------------------------------------------------------
    # Mouse movement dispatch mechanism
    #--------------------------------------------------------------------------

    def on_button_press(self, vtk_picker, event):
        # The event name is e.g. 'LeftButtonPressEvent'; keep 'Left'.
        self._current_button = event[:-len('ButtonPressEvent')]
        self._mouse_no_mvt = 2

    def on_mouse_move(self, vtk_picker, event):
        if self._mouse_no_mvt:
            self._mouse_no_mvt -= 1

    def on_button_release(self, vtk_picker, event):
        """ If the mouse has not moved, pick with our pickers.
        """
        if self._mouse_no_mvt:
            x, y = vtk_picker.GetEventPosition()
            for picker in self._active_pickers.values():
                picker.pick((x, y, 0), self.scene.scene.renderer)
        self._mouse_no_mvt = 0

    def on_pick(self, vtk_picker, event):
        """ Dispatch the pick to the callback associated with the
            corresponding mouse button.
        """
        picker = tvtk.to_tvtk(vtk_picker)
        for event_type, event_picker in self._active_pickers.items():
            if picker is event_picker:
                for callback, picker_type, button in self.callbacks:
                    if (picker_type == event_type
                            and button == self._current_button):
                        callback(picker)
                # Each picker appears once; stop after the match.
                break

    #--------------------------------------------------------------------------
    # Private methods
    #--------------------------------------------------------------------------

    def __del__(self):
        self.clear_callbacks()
class BaseGraph(HasTraits):
    """ Defines a representation of a graph in Graphviz's dot language.
    """

    #--------------------------------------------------------------------------
    #  Trait definitions.
    #--------------------------------------------------------------------------

    # Optional unique identifier.
    ID = id_trait

    # Synonym for ID.
    name = Alias("ID", desc="synonym for ID") # Used by InstanceEditor

    # Main graph nodes.
    nodes = List(Instance(Node))

    # Map of node IDs to node objects.
#    id_node_map = Dict

    # Graph edges.
    edges = List(Instance(Edge))

    # Separate layout regions.
    subgraphs = List(Instance("godot.subgraph.Subgraph"))

    # Clusters are encoded as subgraphs whose names have the prefix 'cluster'.
    clusters = List(Instance("godot.cluster.Cluster"))

    # Node from which new nodes are cloned.
    default_node = Instance(Node)

    # Edge from which new edges are cloned.
    default_edge = Instance(Edge)

    # Graph from which new subgraphs are cloned.
    default_graph = Instance(HasTraits)

    # Level of the graph in the subgraph hierarchy.
#    level = Int(0, desc="level in the subgraph hierarchy")

    # Padding to use for pretty string output.
    padding = Str(" ", desc="padding for pretty printing")

    # A dictionary containing the Graphviz executable names as keys
    # and their paths as values.  See the trait initialiser.
    programs = Dict(desc="names and paths of Graphviz executables")

    # The Graphviz layout program
    program = Enum("dot", "circo", "neato", "twopi", "fdp",
        desc="layout program used by Graphviz")

    # Format for writing to file.
    format = Enum(FORMATS, desc="format used when writing to file")

    # Use Graphviz to arrange all graph components.
    arrange = Button("Arrange All")

    # Parses the Xdot attributes for all graph components.
    redraw = Button("Redraw Canvas")

    #--------------------------------------------------------------------------
    #  Enable trait definitions.
    #--------------------------------------------------------------------------

    # Container of graph components.
    component = Instance(Container, desc="container of graph components.")

    # A view into a sub-region of the canvas.
    vp = Instance(Viewport, desc="a view of a sub-region of the canvas")

    #--------------------------------------------------------------------------
    #  Xdot trait definitions:
    #--------------------------------------------------------------------------

    # For a given graph object, one will typically a draw directive before the
    # label directive. For example, for a node, one would first use the
    # commands in _draw_ followed by the commands in _ldraw_.
    _draw_ = Str(desc="xdot drawing directive")

    # Label draw directive.
    _ldraw_ = Str(desc="xdot label drawing directive")

    #--------------------------------------------------------------------------
    #  "object" interface:
    #--------------------------------------------------------------------------

#    def __init__(self, **traits):
#        """ Initialises a new BaseGraph instance.
#        """
#        super(BaseGraph, self).__init__(**traits)
#
#        # Automatically creates all the methods enabling the saving
#        # of output in any of the supported formats.
#        for frmt in FORMATS:
#            self.__setattr__('save_'+frmt,
#                lambda flo, f=frmt, prog=self.program: \
#                    flo.write( self.create(format=f, prog=prog) ))
#            f = self.__dict__['save_'+frmt]
#            f.__doc__ = '''Refer to the docstring accompanying the 'create'
#            method for more information.'''


    def __len__(self):
        """ Return the order of the graph when requested by len().

        @rtype:  number
        @return: Number of nodes in the graph.
        """
        return len(self.nodes)


    def __iter__(self):
        """ Return a iterator passing through all nodes in the graph.

        @rtype:  iterator
        @return: Iterator passing through all nodes in the graph.
        """
        for each in self.nodes:
            yield each


    def __getitem__(self, node):
        """ Return an iterator over the edges incident to the given node.

        An edge is yielded when the node is either its tail or its head.

        @rtype:  iterator
        @return: Iterator passing through all edges touching the given node.
        """
        for each_edge in self.edges:
            if (each_edge.tail_node == node) or (each_edge.head_node == node):
                yield each_edge


    def __str__(self):
        """ Returns a string representation of the graph in dot language. It
        will return the graph and all its subelements in string form.
        """
        s = ""

        padding = self.padding

        if self.ID:
            s += "%s {\n" % self.ID
        else:
            s += "{\n"

        # Traits to be included in string output have 'graphviz' metadata.
        for trait_name, trait in self.traits(graphviz=True).items():
            # Get the value of the trait for comparison with the default.
            value = getattr(self, trait_name)

            # Only print attribute value pairs if not defaulted.
            # FIXME: Alias/Synced traits default to None.
            if (value != trait.default) and (trait.default is not None):
                if isinstance(value, basestring):
                    # Add double quotes to the value if it is a string.
                    valstr = '"%s"' % value
                else:
                    valstr = str(value)

                s += "%s%s=%s;\n" % (padding, trait_name, valstr)

        def prepend_padding(s):
            """ Indents every line of 's' by the graph padding. """
            return "\n".join([padding + line for line in s.splitlines()])

        for node in self.nodes:
            s += "%s%s\n" % (padding, str(node))
        for edge in self.edges:
            s += "%s%s\n" % (padding, str(edge))
        for subgraph in self.subgraphs:
            s += prepend_padding(str(subgraph)) + "\n"
        for cluster in self.clusters:
            s += prepend_padding(str(cluster)) + "\n"

        s += "}"

        return s

    #--------------------------------------------------------------------------
    #  Trait initialisers:
    #--------------------------------------------------------------------------

    def _default_node_default(self):
        """ Trait initialiser.
        """
        return Node("default")


    def _default_edge_default(self):
        """ Trait initialiser.
        """
        return Edge("tail", "head")


    def _default_graph_default(self):
        """ Trait initialiser.
        """
        return godot.cluster.Cluster(ID="cluster_default")


    def _programs_default(self):
        """ Trait initaliser. Locates the Graphviz executables; returns an
        empty dictionary (and logs a warning) if none are found.
        """
        progs = find_graphviz()
        if progs is None:
            logger.warning("GraphViz's executables not found")
            return {}
        else:
            return progs


    def _component_default(self):
        """ Trait initialiser.
        """
        return Container(draw_axes=True, fit_window=False, auto_size=True)


    def _vp_default(self):
        """ Trait initialiser. Builds a zoomable, pannable viewport onto the
        graph component.
        """
        vp = Viewport(component=self.component)
        vp.enable_zoom = True
        vp.view_position = [-5, -5]
        vp.tools.append(ViewportPanTool(vp))
        return vp

    #--------------------------------------------------------------------------
    #  Public interface:
    #--------------------------------------------------------------------------

    def save_dot(self, flo, prog=None):
        """ Writes the dot-language representation of the graph (str(self))
        to the file-like object 'flo'.

        The 'prog' argument is accepted for signature compatibility with the
        other save_* methods and is ignored: no Graphviz program is run.
        """
        flo.write(str(self))


    def save_xdot(self, flo, prog=None):
        """ Writes the graph, laid out by 'prog' (default: self.program), in
        xdot format to the file-like object 'flo'.
        """
        prog = self.program if prog is None else prog
        flo.write(self.create(prog, "xdot"))


    def save_png(self, flo, prog=None):
        """ Writes the graph, laid out by 'prog' (default: self.program), in
        PNG format to the file-like object 'flo'.
        """
        prog = self.program if prog is None else prog
        flo.write(self.create(prog, "png"))


    @classmethod
    def load_dot(cls, flo):
        """ Parses a dot file and returns the resulting graph. """
        parser = godot.dot_data_parser.GodotDataParser()
        return parser.parse_dot_file(flo)


    @classmethod
    def load_xdot(cls, flo):
        """ Parses an xdot file and returns the resulting graph. """
        parser = godot.dot_data_parser.GodotDataParser()
        return parser.parse_dot_file(flo)


    def create(self, prog=None, format=None):
        """ Creates and returns a representation of the graph using the
        Graphviz layout program given by 'prog', according to the given
        format.

        Writes the graph to a temporary dot file and processes it with the
        program given by 'prog' (default: self.program), producing output in
        'format' (default: self.format).

        Returns the program's standard output as a string. If the program
        exits with a non-zero status, the status and its standard error are
        logged and whatever it wrote to standard output (possibly an empty
        string) is returned. The temporary file is always removed.

        Raises KeyError if 'prog' is not a key of self.programs.
        """
        prog = self.program if prog is None else prog
        format = self.format if format is None else format

        # Make a temporary file and save the graph to it.
        tmp_fd, tmp_name = tempfile.mkstemp()
        try:
            with os.fdopen(tmp_fd, "w+b") as dot_fd:
                self.save_dot(dot_fd)

            # Get the temporary file directory name.
            tmp_dir = os.path.dirname(tmp_name)

            # TODO: Shape image files (See PyDot). Important.

            # Process the file using the layout program, specifying the format.
            p = subprocess.Popen(
                (self.programs[prog], '-T' + format, tmp_name),
                cwd=tmp_dir,
                stderr=subprocess.PIPE,
                stdout=subprocess.PIPE)

            # communicate() drains stdout and stderr together, so the child
            # cannot block on a full stderr pipe while we wait on stdout.
            stdout_output, stderr_output = p.communicate()
            status = p.returncode
        finally:
            # TODO: Remove shape image files from the temporary directory.
            # Remove the temporary file even if something above failed.
            os.unlink(tmp_name)

        if status != 0:
            logger.error("Program terminated with status: %d. stderr "
                         "follows: %s", status, stderr_output)
        elif stderr_output:
            logger.error("%s", stderr_output)

        return stdout_output


    @on_trait_change("arrange")
    def arrange_all(self):
        """ Sets the _draw_ and _ldraw_ attributes for each of the graph
        sub-elements by processing the xdot format of the graph.
        """
        import godot.dot_data_parser

        parser = godot.dot_data_parser.GodotDataParser()

        xdot_data = self.create(format="xdot")
        logger.debug("GRAPH DOT:\n%s", str(self))
        logger.debug("XDOT DATA:\n%s", xdot_data)

        parser.dotparser.parseWithTabs()
        # Join continuation lines (backslash-newline) before parsing.
        ndata = xdot_data.replace("\\\n", "")
        tokens = parser.dotparser.parseString(ndata)[0]
        parser.build_graph(graph=self, tokens=tokens[3])


    def add_node(self, node_or_ID, **kwds):
        """ Adds a node to the graph, or reuses an existing one, and returns
        it.

        If 'node_or_ID' is not a Node it is converted to a string ID. A node
        already in the graph that compares equal to that ID is reused;
        otherwise a new node is cloned from 'default_node' (or created
        directly when there is no default node) and appended. If a Node is
        given, an equal node already in the graph is reused, otherwise the
        given node is appended. Keyword arguments are set as traits on the
        resulting node.
        """
        if not isinstance(node_or_ID, Node):
            nodeID = str(node_or_ID)
            if nodeID in self.nodes:
                node = self.nodes[self.nodes.index(nodeID)]
            else:
                if self.default_node is not None:
                    node = self.default_node.clone_traits(copy="deep")
                    node.ID = nodeID
                else:
                    node = Node(nodeID)
                self.nodes.append(node)
        else:
            node = node_or_ID
            if node in self.nodes:
                node = self.nodes[self.nodes.index(node_or_ID)]
            else:
                self.nodes.append(node)

        node.set(**kwds)

        return node


    def delete_node(self, node_or_ID):
        """ Removes a node from the graph.

        Raises ValueError if no node with the given ID exists.
        """
        if isinstance(node_or_ID, Node):
            node = node_or_ID
        else:
            node = self.get_node(node_or_ID)
            if node is None:
                raise ValueError("Node %s does not exists" % node_or_ID)

        self.nodes.remove(node)


    def get_node(self, ID):
        """ Returns the node with the given ID or None.
        """
        for node in self.nodes:
            if node.ID == str(ID):
                return node
        return None


    def delete_edge(self, tail_node_or_ID, head_node_or_ID):
        """ Removes the first edge from tail to head from the graph.

        Returns the deleted edge, or None if either node is unknown or no
        such edge exists.
        """
        if isinstance(tail_node_or_ID, Node):
            tail_node = tail_node_or_ID
        else:
            tail_node = self.get_node(tail_node_or_ID)

        if isinstance(head_node_or_ID, Node):
            head_node = head_node_or_ID
        else:
            head_node = self.get_node(head_node_or_ID)

        if (tail_node is None) or (head_node is None):
            return None

        for i, edge in enumerate(self.edges):
            if (edge.tail_node == tail_node) and (edge.head_node == head_node):
                edge = self.edges.pop(i)
                return edge

        return None


    def add_edge(self, tail_node_or_ID, head_node_or_ID, **kwds):
        """ Adds an edge to the graph, adding its end nodes if necessary.
        """
        tail_node = self.add_node(tail_node_or_ID)
        head_node = self.add_node(head_node_or_ID)

        # Only top level graphs are directed and/or strict.
        if "directed" in self.trait_names():
            directed = self.directed
        else:
            directed = False

        if self.default_edge is not None:
            edge = self.default_edge.clone_traits(copy="deep")
            edge.tail_node = tail_node
            edge.head_node = head_node
            edge.conn = "->" if directed else "--"
            edge.set(**kwds)
        else:
            edge = Edge(tail_node, head_node, directed, **kwds)

        # FIXME: Implement strict graphs. Edges are currently appended
        # whether or not the graph is strict.
        self.edges.append(edge)


    def add_subgraph(self, subgraph_or_ID):
        """ Adds a subgraph to the graph and returns it.

        A non-subgraph argument is converted to a string ID; IDs starting
        with "cluster" produce a Cluster, anything else a Subgraph. The
        subgraph inherits this graph's default node and edge.
        """
        if not isinstance(subgraph_or_ID, (godot.subgraph.Subgraph,
                                           godot.cluster.Cluster)):
            subgraphID = str(subgraph_or_ID)
            # Test the stringified ID so non-string IDs are accepted.
            if subgraphID.startswith("cluster"):
                subgraph = godot.cluster.Cluster(ID=subgraphID)
            else:
                subgraph = godot.subgraph.Subgraph(ID=subgraphID)
        else:
            subgraph = subgraph_or_ID

        subgraph.default_node = self.default_node
        subgraph.default_edge = self.default_edge
#        subgraph.level = self.level + 1
#        subgraph.padding += self.padding

        if isinstance(subgraph, godot.subgraph.Subgraph):
            self.subgraphs.append(subgraph)
        elif isinstance(subgraph, godot.cluster.Cluster):
            self.clusters.append(subgraph)
        else:
            raise TypeError("Expected a Subgraph or Cluster, got %r" %
                            (subgraph,))

        return subgraph


    def add_cluster(self, cluster_or_ID):
        """ Adds a cluster to the graph.
        """
        return self.add_subgraph(cluster_or_ID)

    #--------------------------------------------------------------------------
    #  "BaseGraph" interface:
    #--------------------------------------------------------------------------

    def _program_changed(self, new):
        """ Handles the Graphviz layout program selection changing.

        Logs a warning if the selected program is not among the located
        Graphviz executables, or if its path is not an existing file.
        """
        progs = self.programs

        if new not in progs:
            logger.warning('GraphViz\'s executable "%s" not found' % new)
            # Nothing further to check without a path.
            return

        if not os.path.isfile(progs[new]):
            logger.warning("GraphViz's executable '%s' is not a "
                "file or doesn't exist" % progs[new])


    def _component_changed(self, new):
        """ Handles the graph canvas changing.
        """
        self.vp.component = new


#    @on_trait_change("nodes,nodes_items")
#    def remove_duplicates(self, new):
#        """ Ensures node ID uniqueness.
#        """
#        if isinstance(new, TraitListEvent):
#            old = event.removed
#            new = event.added
#
#        set = {}
#        self.set( trait_change_notify = False,
#            nodes = [set.setdefault(e, e) for e in new if e not in set] )


    @on_trait_change("nodes,nodes_items")
    def _set_node_lists(self, new):
        """ Maintains each edge's list of available nodes.
        """
        for edge in self.edges:
            edge._nodes = self.nodes
class Preferences(HasTraits): """ The default implementation of a node in a preferences hierarchy. """ implements(IPreferences) #### 'IPreferences' interface ############################################# # The absolute path to this node from the root node (the empty string if # this node *is* the root node). path = Property(Str) # The parent node (None if this node *is* the root node). parent = Instance(IPreferences) # The name of the node relative to its parent (the empty string if this # node *is* the root node). name = Str #### 'Preferences' interface ############################################## # The default name of the file used to persist the preferences (if no # filename is passed in to the 'load' and 'save' methods, then this is # used instead). filename = Str #### Protected 'Preferences' interface #################################### # A lock to make access to the node thread-safe. # # fixme: There *should* be no need to declare this as a trait, but if we # don't then we have problems using nodes in the preferences manager UI. # It is something to do with 'cloning' the node for use in a 'modal' traits # UI... Hmmm... _lk = Any # The node's children. _children = Dict(Str, IPreferences) # The node's preferences. _preferences = Dict(Str, Any) # Listeners for changes to the node's preferences. # # The callable must take 4 arguments, e.g:: # # listener(node, key, old, new) _preferences_listeners = List(Callable) ########################################################################### # 'object' interface. ########################################################################### def __init__(self, **traits): """ Constructor. """ # A lock to make access to the '_children', '_preferences' and # '_preferences_listeners' traits thread-safe. self._lk = threading.Lock() # Base class constructor. super(Preferences, self).__init__(**traits) # If a filename has been specified then load the preferences from it. 
if len(self.filename) > 0: self.load() return ########################################################################### # 'IPreferences' interface. ########################################################################### #### Trait properties ##################################################### def _get_path(self): """ Property getter. """ names = [] node = self while node.parent is not None: names.append(node.name) node = node.parent names.reverse() return '.'.join(names) #### Methods ############################################################## #### Methods where 'path' refers to a preference #### def get(self, path, default=None, inherit=False): """ Get the value of the preference at the specified path. """ if len(path) == 0: raise ValueError('empty path') components = path.split('.') # If there is only one component in the path then the operation takes # place in this node. if len(components) == 1: value = self._get(path, Undefined) # Otherwise, find the next node and pass the rest of the path to that. else: node = self._get_child(components[0]) if node is not None: value = node.get('.'.join(components[1:]), Undefined) else: value = Undefined # If inherited values are allowed then try those as well. # # e.g. 'acme.ui.widget.bgcolor' # 'acme.ui.bgcolor' # 'acme.bgcolor' # 'bgcolor' while inherit and value is Undefined and len(components) > 1: # Remove the penultimate component... # # e.g. 'acme.ui.widget.bgcolor' -> 'acme.ui.bgcolor' del components[-2] # ... and try that. value = self.get('.'.join(components), default=Undefined) if value is Undefined: value = default return value def remove(self, path): """ Remove the preference at the specified path. """ if len(path) == 0: raise ValueError('empty path') components = path.split('.') # If there is only one component in the path then the operation takes # place in this node. if len(components) == 1: self._remove(path) # Otherwise, find the next node and pass the rest of the path to that. 
else: node = self._get_child(components[0]) if node is not None: node.remove('.'.join(components[1:])) return def set(self, path, value): """ Set the value of the preference at the specified path. """ if len(path) == 0: raise ValueError('empty path') components = path.split('.') # If there is only one component in the path then the operation takes # place in this node. if len(components) == 1: self._set(path, value) # Otherwise, find the next node (creating it if it doesn't exist) # and pass the rest of the path to that. else: node = self._node(components[0]) node.set('.'.join(components[1:]), value) return #### Methods where 'path' refers to a node #### def clear(self, path=''): """ Remove all preferences from the node at the specified path. """ # If the path is empty then the operation takes place in this node. if len(path) == 0: self._clear() # Otherwise, find the next node and pass the rest of the path to that. else: components = path.split('.') node = self._get_child(components[0]) if node is not None: node.clear('.'.join(components[1:])) return def keys(self, path=''): """ Return the preference keys of the node at the specified path. """ # If the path is empty then the operation takes place in this node. if len(path) == 0: keys = self._keys() # Otherwise, find the next node and pass the rest of the path to that. else: components = path.split('.') node = self._get_child(components[0]) if node is not None: keys = node.keys('.'.join(components[1:])) else: keys = [] return keys def node(self, path=''): """ Return the node at the specified path. """ # If the path is empty then the operation takes place in this node. if len(path) == 0: node = self # Otherwise, find the next node and pass the rest of the path to that. else: components = path.split('.') node = self._node(components[0]) node = node.node('.'.join(components[1:])) return node def node_exists(self, path=''): """ Return True if the node at the specified path exists. 
""" # If the path is empty then the operation takes place in this node. if len(path) == 0: exists = True # Otherwise, find the next node and pass the rest of the path to that. else: components = path.split('.') node = self._get_child(components[0]) if node is not None: exists = node.node_exists('.'.join(components[1:])) else: exists = False return exists def node_names(self, path=''): """ Return the names of the children of the node at the specified path. """ # If the path is empty then the operation takes place in this node. if len(path) == 0: names = self._node_names() # Otherwise, find the next node and pass the rest of the path to that. else: components = path.split('.') node = self._get_child(components[0]) if node is not None: names = node.node_names('.'.join(components[1:])) else: names = [] return names #### Persistence methods #### def flush(self): """ Force any changes in the node to the backing store. This includes any changes to the node's descendants. """ self.save() return ########################################################################### # 'Preferences' interface. ########################################################################### #### Listener methods #### def add_preferences_listener(self, listener, path=''): """ Add a listener for changes to a node's preferences. """ # If the path is empty then the operation takes place in this node. if len(path) == 0: names = self._add_preferences_listener(listener) # Otherwise, find the next node and pass the rest of the path to that. else: components = path.split('.') node = self._node(components[0]) node.add_preferences_listener(listener, '.'.join(components[1:])) return def remove_preferences_listener(self, listener, path=''): """ Remove a listener for changes to a node's preferences. """ # If the path is empty then the operation takes place in this node. 
if len(path) == 0: names = self._remove_preferences_listener(listener) # Otherwise, find the next node and pass the rest of the path to that. else: components = path.split('.') node = self._node(components[0]) node.remove_preferences_listener(listener, '.'.join(components[1:])) return #### Persistence methods #### def load(self, file_or_filename=None): """ Load preferences from a file. This is a *merge* operation i.e. the contents of the file are added to the node. This implementation uses 'ConfigObj' files. """ if file_or_filename is None: file_or_filename = self.filename logger.debug('loading preferences from <%s>', file_or_filename) # Do the import here so that we don't make 'ConfigObj' a requirement # if preferences aren't ever persisted (or a derived class chooses to # use a different persistence mechanism). from configobj import ConfigObj config_obj = ConfigObj(file_or_filename) # 'name' is the section name, 'value' is a dictionary containing the # name/value pairs in the section (the actual preferences ;^). for name, value in config_obj.items(): # Create/get the node from the section name. components = name.split('.') node = self for component in components: node = node._node(component) # Add the contents of the section to the node. self._add_dictionary_to_node(node, value) return def save(self, file_or_filename=None): """ Save the node's preferences to a file. This implementation uses 'ConfigObj' files. """ if file_or_filename is None: file_or_filename = self.filename # If no file or filename is specified then don't save the preferences! if len(file_or_filename) > 0: # Do the import here so that we don't make 'ConfigObj' a # requirement if preferences aren't ever persisted (or a derived # class chooses to use a different persistence mechanism). 
from configobj import ConfigObj logger.debug('saving preferences to <%s>', file_or_filename) config_obj = ConfigObj(file_or_filename) self._add_node_to_dictionary(self, config_obj) config_obj.write() return ########################################################################### # Protected 'Preferences' interface. # # These are the only methods that should access the protected '_children' # and '_preferences' traits. This helps make it easy to subclass this class # to create other implementations (all the subclass has to do is to # implement these protected methods). # ########################################################################### def _add_dictionary_to_node(self, node, dictionary): """ Add the contents of a dictionary to a node's preferences. """ self._lk.acquire() node._preferences.update(dictionary) self._lk.release() return def _add_node_to_dictionary(self, node, dictionary): """ Add a node's preferences to a dictionary. """ # This method never manipulates the '_preferences' trait directly. # Instead it does eveything via the other protected methods and hence # doesn't need to grab the lock. if len(node._keys()) > 0: dictionary[node.path] = {} for key in node._keys(): dictionary[node.path][key] = node._get(key) for name in node._node_names(): self._add_node_to_dictionary(node._get_child(name), dictionary) return def _add_preferences_listener(self, listener): """ Add a listener for changes to thisnode's preferences. """ self._lk.acquire() self._preferences_listeners.append(listener) self._lk.release() return def _clear(self): """ Remove all preferences from this node. """ self._lk.acquire() self._preferences.clear() self._lk.release() return def _create_child(self, name): """ Create a child of this node with the specified name. """ self._lk.acquire() child = self._children[name] = Preferences(name=name, parent=self) self._lk.release() return child def _get(self, key, default=None): """ Get the value of a preference in this node. 
""" self._lk.acquire() value = self._preferences.get(key, default) self._lk.release() return value def _get_child(self, name): """ Return the child of this node with the specified name. Return None if no such child exists. """ self._lk.acquire() child = self._children.get(name) self._lk.release() return child def _keys(self): """ Return the preference keys of this node. """ self._lk.acquire() keys = self._preferences.keys() self._lk.release() return keys def _node(self, name): """ Return the child of this node with the specified name. Create the child node if it does not exist. """ node = self._get_child(name) if node is None: node = self._create_child(name) return node def _node_names(self): """ Return the names of the children of this node. """ self._lk.acquire() node_names = self._children.keys() self._lk.release() return node_names def _remove(self, name): """ Remove a preference value from this node. """ self._lk.acquire() if name in self._preferences: del self._preferences[name] self._lk.release() return def _remove_preferences_listener(self, listener): """ Remove a listener for changes to the node's preferences. """ self._lk.acquire() if listener in self._preferences_listeners: self._preferences_listeners.remove(listener) self._lk.release() return def _set(self, key, value): """ Set the value of a preference in this node. """ # Preferences are *always* stored as strings. value = str(value) self._lk.acquire() old = self._preferences.get(key) self._preferences[key] = value # If the value is unchanged then don't call the listeners! if old == value: listeners = [] else: listeners = self._preferences_listeners[:] self._lk.release() for listener in listeners: listener(self, key, old, value) return ########################################################################### # Debugging interface. ########################################################################### def dump(self, indent=''): """ Dump the preferences hierarchy to stdout. 
""" if indent == '': print print indent, 'Node(%s)' % self.name, self._preferences indent += ' ' for child in self._children.values(): child.dump(indent) return
class TemplateDataContext(HasPrivateTraits):
    """ A concrete implementation of the ITemplateDataContext interface
        intended to be used for creating the *output_data_context* value of an
        **ITemplateDataNameItem** implementation (although they are not
        required to use it).
    """

    implements(ITemplateDataContext)

    #-- 'ITemplateDataContext' Interface Traits --------------------------------

    # The path to this data context (does not include the 'data_context_name'):
    data_context_path = Str

    # The name of the data context:
    data_context_name = Str

    # A list of the names of the data values in this context:
    # NOTE(review): the getter is a cached_property but this Property declares
    # no dependencies, so the cached list is presumably never refreshed if
    # 'values' changes after first access -- confirm that is intended.
    data_context_values = Property # List( Str )

    # The list of the names of the sub-contexts of this context:
    # NOTE(review): same caching caveat as 'data_context_values'.
    data_contexts = Property # List( Str )

    #-- Public Traits ---------------------------------------------------------

    # The data context values dictionary:
    values = Dict(Str, Any)

    # The data contexts dictionary:
    contexts = Dict(Str, ITemplateDataContext)

    #-- 'ITemplateDataContext' Property Implementations ------------------------

    @cached_property
    def _get_data_context_values(self):
        """ Returns the sorted names of the data values in this context. """
        return sorted(self.values.keys())

    @cached_property
    def _get_data_contexts(self):
        """ Returns the sorted names of the sub-contexts of this context. """
        return sorted(self.contexts.keys())

    #-- 'ITemplateDataContext' Interface Implementation ------------------------

    def get_data_context_value(self, name):
        """ Returns the data value with the specified *name*. Raises a
            **ITemplateDataContextError** if *name* is not defined as a data
            value in the context.

            Parameters
            ----------
            name : A string specifying the name of the context data value to
                be returned.

            Returns
            -------
            The data value associated with *name* in the context. The type of
            the data is application dependent.

            Raises **ITemplateDataContextError** if *name* is not associated
            with a data value in the context.
        """
        try:
            return self.values[name]
        except KeyError:
            raise ITemplateDataContextError("Value '%s' not found." % name)

    def get_data_context(self, name):
        """ Returns the **ITemplateDataContext** value associated with the
            specified *name*. Raises **ITemplateDataContextError** if *name* is
            not defined as a data context in the context.

            Parameters
            ----------
            name : A string specifying the name of the data context to be
                returned.

            Returns
            -------
            The **ITemplateDataContext** associated with *name* in the context.

            Raises **ITemplateDataContextError** if *name* is not associated
            with a data context in the context.
        """
        try:
            # The trait is 'contexts'; the former 'self.context' raised an
            # AttributeError that was masked as "not found" for every name.
            return self.contexts[name]
        except KeyError:
            raise ITemplateDataContextError("Context '%s' not found." % name)
class MultiFitGui(HasTraits):
    """
    data should be c x N where c is the number of data columns/axes and N is
    the number of points
    """
    doplot3d = Bool(False)
    show3d = Button('Show 3D Plot')
    replot3d = Button('Replot 3D')
    scalefactor3d = Float(0)
    do3dscale = Bool(False)
    nmodel3d = Int(1024)
    usecolor3d = Bool(False)
    color3d = Color((0,0,0))
    scene3d = Instance(MlabSceneModel,())
    plot3daxes = Tuple(('x','y','z'))
    data = Array(shape=(None,None))
    weights = Array(shape=(None,))
    curveaxes = List(Tuple(Int,Int))
    axisnames = Dict(Int,Str)
    invaxisnames = Property(Dict,depends_on='axisnames')

    fgs = List(Instance(FitGui))

    traits_view = View(VGroup(Item('fgs',editor=ListEditor(use_notebook=True,page_name='.plotname'),style='custom',show_label=False),
                              Item('show3d',show_label=False)),
                       resizable=True,height=900,buttons=['OK','Cancel'],title='Multiple Model Data Fitters')

    plot3d_view = View(VGroup(Item('scene3d',editor=SceneEditor(scene_class=MayaviScene),show_label=False,resizable=True),
                              Item('plot3daxes',editor=TupleEditor(cols=3,labels=['x','y','z']),label='Axes'),
                              HGroup(Item('do3dscale',label='Scale by weight?'),
                                     Item('scalefactor3d',label='Point scale'),
                                     Item('nmodel3d',label='Nmodel')),
                              HGroup(Item('usecolor3d',label='Use color?'),
                                     Item('color3d',label='Relation Color',enabled_when='usecolor3d')),
                              Item('replot3d',show_label=False),springy=True),
                       resizable=True,height=800,width=800,title='Multiple Model3D Plot')

    def __init__(self,data,names=None,models=None,weights=None,dofits=True,**traits):
        """
        :param data: The data arrays
        :type data: sequence of c equal-length arrays (length N)
        :param names: Names (a comma-separated string is also accepted)
        :type names: sequence of strings, length c
        :param models:
            The models to fit for each pair either as strings or
            :class:`astropysics.models.ParametricModel` objects.
        :type models: sequence of models, length c-1
        :param weights: the weights for each point or None for no weights
        :type weights: array-like of size N or None
        :param dofits:
            If True, the data will be fit to the models when the object is
            created, otherwise the models will be passed in as-is (or as
            created).
        :type dofits: bool

        :raises ValueError:
            If `data` is not 2D with at least 2 columns, or if `names` or
            `models` do not match the number of data columns.

        extra keyword arguments get passed in as new traits
        """
        super(MultiFitGui,self).__init__(**traits)
        self._lastcurveaxes = None

        # asarray avoids a copy when possible (what np.array(copy=False)
        # meant on older numpy) and does not raise on numpy >= 2.
        data = np.asarray(data)
        # validate the shape before anything indexes data.shape[1]
        if data.ndim != 2:
            raise ValueError('data must be a 2D array (c x N)')
        if data.shape[0] < 2:
            raise ValueError('Must have at least 2 columns')

        if weights is None:
            self.weights = np.ones(data.shape[1])
        else:
            self.weights = np.array(weights)

        self.data = data

        if isinstance(names,basestring):
            names = names.split(',')
        if names is None:
            if len(data) == 2:
                self.axisnames = {0:'x',1:'y'}
            elif len(data) == 3:
                self.axisnames = {0:'x',1:'y',2:'z'}
            else:
                # keys are axis indices (not the data rows themselves)
                self.axisnames = dict((i,str(i)) for i in range(len(data)))
        elif len(names) == len(data):
            self.axisnames = dict(enumerate(names))
        else:
            raise ValueError("names don't match data")

        #default to using 0th axis as parametric
        self.curveaxes = [(0,i) for i in range(1,len(data))]
        if models is not None:
            if len(models) != len(data)-1:
                raise ValueError("models don't match data")
            for i,m in enumerate(models):
                fg = self.fgs[i]
                newtmodel = TraitedModel(m)
                if dofits:
                    fg.tmodel = newtmodel
                    fg.fitmodel = True #should happen automatically, but this makes sure
                else:
                    # Assigning tmodel may alter the model's parameters
                    # (presumably FitGui refits/resets them -- TODO confirm),
                    # so restore the parameters the model came in with.
                    oldpard = newtmodel.model.pardict
                    fg.tmodel = newtmodel
                    fg.tmodel.model.pardict = oldpard

    def _update_axis_titles(self,fg,ax):
        """Title FitGui `fg`'s axes from `axisnames` for the (x, y) index
        pair `ax`; axes without a name get an empty title."""
        fg.plot.x_axis.title = self.axisnames.get(ax[0],'')
        fg.plot.y_axis.title = self.axisnames.get(ax[1],'')

    def _new_fitgui(self,ax):
        """Create a titled FitGui for the (x, y) data index pair `ax`."""
        fg = FitGui(self.data[ax[0]],self.data[ax[1]],weights=self.weights)
        self._update_axis_titles(fg,ax)
        return fg

    def _data_changed(self):
        self.curveaxes = [(0,i) for i in range(1,len(self.data))]

    def _axisnames_changed(self):
        for ax,fg in zip(self.curveaxes,self.fgs):
            self._update_axis_titles(fg,ax)
        names = self.axisnames
        self.plot3daxes = (names[0],names[1],names[2] if len(names) > 2 else names[1])

    @on_trait_change('curveaxes[]')
    def _curveaxes_update(self,names,old,new):
        """Keep `fgs` in sync with `curveaxes`; reject pairings that leave a
        data axis unused by restoring the previous pairing."""
        used = set()
        for t in self.curveaxes:
            used.update(t)
        if used != set(range(len(self.data))):
            if self._lastcurveaxes is not None:
                self.curveaxes = list(self._lastcurveaxes)
            return

        #TODO: check for recursion
        last = self._lastcurveaxes
        if last is None or len(last) != len(self.curveaxes):
            # first call, or the number of pairs changed: rebuild everything
            self.fgs = [self._new_fitgui(t) for t in self.curveaxes]
        else:
            for i,t in enumerate(self.curveaxes):
                if last[i] != t:
                    self.fgs[i] = self._new_fitgui(t)
        # Store a copy: keeping the live trait list would make later in-place
        # edits to curveaxes also change this snapshot, hiding the change.
        self._lastcurveaxes = list(self.curveaxes)

    def _show3d_fired(self):
        self.edit_traits(view='plot3d_view')
        self.doplot3d = True
        self.replot3d = True

    def _plot3daxes_changed(self):
        self.replot3d = True

    @on_trait_change('weights',post_init=True)
    def weightsChanged(self):
        for fg in self.fgs:
            if fg.weighttype != 'custom':
                fg.weighttype = 'custom'
            fg.weights = self.weights

    @on_trait_change('data','fgs','replot3d','weights')
    def _do_3d(self):
        """Redraw the 3D scene (when enabled): the data points, axes, and the
        x -> y, x -> z relation obtained by chaining the fitted models."""
        if self.doplot3d:
            M = self.scene3d.mlab
            try:
                xi = self.invaxisnames[self.plot3daxes[0]]
                yi = self.invaxisnames[self.plot3daxes[1]]
                zi = self.invaxisnames[self.plot3daxes[2]]

                x,y,z = self.data[xi],self.data[yi],self.data[zi]
                w = self.weights

                M.clf()
                if self.scalefactor3d == 0:
                    sf = x.max()-x.min()
                    sf *= y.max()-y.min()
                    sf *= z.max()-z.min()
                    # float() so integer data does not truncate the scale to 0
                    sf = float(sf)/len(x)/5
                    self.scalefactor3d = sf
                else:
                    sf = self.scalefactor3d
                glyph = M.points3d(x,y,z,w,scale_factor=sf)
                glyph.glyph.scale_mode = 0 if self.do3dscale else 1
                M.axes(xlabel=self.plot3daxes[0],ylabel=self.plot3daxes[1],zlabel=self.plot3daxes[2])

                try:
                    xs = np.linspace(np.min(x),np.max(x),self.nmodel3d)

                    #find sequence of models to go from x to y and z
                    ymods,zmods = [],[]
                    for curri,mods in zip((yi,zi),(ymods,zmods)):
                        visited = set()
                        while curri != xi:
                            if curri in visited:
                                # the pairs form a loop that never reaches x
                                raise KeyError(curri)
                            visited.add(curri)
                            for i,(i1,i2) in enumerate(self.curveaxes):
                                if curri==i2:
                                    curri = i1
                                    mods.insert(0,self.fgs[i].tmodel.model)
                                    break
                            else:
                                raise KeyError(curri)

                    ys = xs
                    for m in ymods:
                        ys = m(ys)
                    zs = xs
                    for m in zmods:
                        zs = m(zs)

                    if self.usecolor3d:
                        # float division: mlab colors are 0-1 floats
                        c = (self.color3d[0]/255.0,self.color3d[1]/255.0,self.color3d[2]/255.0)
                        M.plot3d(xs,ys,zs,color=c)
                    else:
                        M.plot3d(xs,ys,zs,np.arange(len(xs)))
                except (KeyError,TypeError):
                    M.text(0.5,0.75,'Underivable relation')
            except KeyError:
                M.clf()
                M.text(0.25,0.25,'Data problem')

    @cached_property
    def _get_invaxisnames(self):
        """Inverse of `axisnames`: maps each axis name to its index."""
        return dict((v,k) for k,v in self.axisnames.items())
class tcGeneric(HasTraits):
    name = String
    start_ts = CArray
    end_ts = CArray
    types = CArray
    has_comments = Bool(True)
    total_time = Property(Int)
    max_types = Property(Int)
    max_latency = Property(Int)
    max_latency_ts = Property(CArray)

    overview_ts_cache = Dict({})

    @cached_property
    def _get_total_time(self):
        """Sum of (end_ts - start_ts) over all events."""
        durations = self.end_ts - self.start_ts
        return sum(durations)

    @cached_property
    def _get_max_types(self):
        """Largest entry of the `types` array."""
        return amax(self.types)

    @cached_property
    def _get_max_latency(self):
        # Always -1 here; presumably overridden where latency is tracked.
        return -1

    def get_partial_tables(self, start, end):
        """Return (starts, ends, types) for the events overlapping [start, end].

        The first start and last end are clipped to the window.  If no event
        overlaps, two empty arrays and an empty list are returned.
        """
        first = searchsorted(self.end_ts, start)
        last = searchsorted(self.start_ts, end)
        starts = self.start_ts[first:last].copy()
        ends = self.end_ts[first:last].copy()
        if not len(starts):
            return np.array([]), np.array([]), []

        # clip activities that cross the window boundaries
        if start > starts[0]:
            starts[0] = start
        if end < ends[-1]:
            ends[-1] = end
        return starts, ends, self.types[first:last]

    def get_overview_ts(self, threshold):
        """Return (starts, ends) with events merged per "threshold" microsecond.

        A new merged event begins whenever an event starts more than
        *threshold* after the previous one; otherwise it joins the current
        merged event, which then ends where its last member ends.  Results
        are cached per threshold, and inputs of fewer than 500 events are
        returned unmerged.
        """
        cache = self.overview_ts_cache
        if threshold in cache:
            return cache[threshold]

        # Start from the overview at half the threshold; the recursion fills
        # the cache for all the smaller thresholds as a side benefit.
        if threshold <= 4:
            src_starts, src_ends = self.start_ts, self.end_ts
        else:
            src_starts, src_ends = self.get_overview_ts(threshold / 2)

        if len(src_starts) < 500:
            # too few events for merging to be worthwhile
            overview = (src_starts, src_ends)
        else:
            count = len(src_starts)
            merged_starts = []
            merged_ends = []
            run_start = src_starts[0]
            for j in range(1, count):
                if src_starts[j] > src_starts[j - 1] + threshold:
                    merged_starts.append(run_start)
                    merged_ends.append(src_ends[j - 1])
                    run_start = src_starts[j]
            merged_starts.append(run_start)
            merged_ends.append(src_ends[count - 1])
            overview = (numpy.array(merged_starts), numpy.array(merged_ends))

        cache[threshold] = overview
        return overview

    # UI traits
    default_bg_color = Property(ColorTrait)
    bg_color = Property(ColorTrait)

    @cached_property
    def _get_bg_color(self):
        return colors.get_traits_color_by_name("idle_bg")
from enthought.traits.api import (
    Trait, HasStrictTraits, List, Dict, Str, Int, Any)

from enthought.traits.trait_base import enumerate

from view_element import ViewElement

#-------------------------------------------------------------------------------
#  Trait definitions:
#-------------------------------------------------------------------------------

# Dictionary trait mapping a str name to the ViewElement stored under it
content_trait = Dict(str, ViewElement)

#-------------------------------------------------------------------------------
#  'ViewElements' class:
#-------------------------------------------------------------------------------

class ViewElements(HasStrictTraits):
    """ A hierarchical name space grouping related ViewElement objects.
    """

    #---------------------------------------------------------------------------
    #  Trait definitions:
    #---------------------------------------------------------------------------

    # The named ViewElement items held by this name space
    content = content_trait
class BuiltinImage(Source):

    # The version of this class.  Used for persistence.
    __version__ = 0

    # Flag to set the image data type.
    source = Enum('ellipsoid','gaussian','grid','mandelbrot','noise',
                  'sinusoid','rt_analytic',
                  desc='which image data source to be used')

    # Define the trait 'data_source' whose value must be an instance of
    # type ImageAlgorithm
    data_source = Instance(tvtk.ImageAlgorithm, allow_none=False,
                           record=True)

    # Information about what this object can produce.
    output_info = PipelineInfo(datasets=['image_data'],
                               attribute_types=['any'],
                               attributes=['any'])

    # Create the UI for the traits.
    view = View(Group(Item(name='source'),
                      Item(name='data_source',
                           style='custom',
                           resizable=True),
                      label='Image Source',
                      show_labels=False),
                resizable=True)

    ########################################
    # Private traits.

    # A dictionary that maps the source names to instances of the
    # image data objects.
    _source_dict = Dict(Str,
                        Instance(tvtk.ImageAlgorithm,
                                 allow_none=False))

    ######################################################################
    # `object` interface
    ######################################################################
    def __init__(self, **traits):
        # Call parent class' init.
        super(BuiltinImage, self).__init__(**traits)

        # When 'source' is passed in, setting it already ran _source_changed;
        # the default value does not, so apply it explicitly here.
        if 'source' not in traits:
            self._source_changed(self.source)

    def __set_pure_state__(self, state):
        self.source = state.source
        super(BuiltinImage, self).__set_pure_state__(state)

    ######################################################################
    # Non-public methods.
    ######################################################################
    def _source_changed(self, value):
        """This method is invoked (automatically) when the `source`
        trait is changed; it selects the matching image source object
        from `_source_dict` as the new `data_source`.
        """
        self.data_source = self._source_dict[self.source]

    def _data_source_changed(self, old, new):
        """This method is invoked (automatically) when the image data
        source is changed; it republishes the new source's output and moves
        the render listener from the old source to the new one."""
        self.outputs = [self.data_source.output]
        # Re-render whenever the active source's traits change, and stop
        # listening to the previous source.
        if old is not None:
            old.on_trait_change(self.render, remove=True)
        new.on_trait_change(self.render)

    def __source_dict_default(self):
        """The default _source_dict trait."""
        sd = {
            'ellipsoid':tvtk.ImageEllipsoidSource(),
            'gaussian':tvtk.ImageGaussianSource(),
            'grid':tvtk.ImageGridSource(),
            'mandelbrot':tvtk.ImageMandelbrotSource(),
            'noise':tvtk.ImageNoiseSource(),
            'sinusoid':tvtk.ImageSinusoidSource(),
        }
        # Not every tvtk build provides RTAnalyticSource; fall back to a
        # noise source so the 'rt_analytic' choice always resolves.
        if hasattr(tvtk, 'RTAnalyticSource'):
            sd['rt_analytic'] = tvtk.RTAnalyticSource()
        else:
            sd['rt_analytic'] = tvtk.ImageNoiseSource()
        return sd
class DatasetManager(HasTraits):

    # The TVTK dataset we manage.
    dataset = Instance(tvtk.DataSet)

    # Our output, this is the dataset modified by us with different
    # active arrays.
    output = Property(Instance(tvtk.DataSet))

    # The point scalars for the dataset.  You may manipulate the arrays
    # in-place.  However adding new keys in this dict will not set the
    # data in the `dataset` for that you must explicitly call
    # `add_array`.
    point_scalars = Dict(Str, Array)
    # Point vectors.
    point_vectors = Dict(Str, Array)
    # Point tensors.
    point_tensors = Dict(Str, Array)

    # The cell scalars for the dataset.
    cell_scalars = Dict(Str, Array)
    cell_vectors = Dict(Str, Array)
    cell_tensors = Dict(Str, Array)

    # This filter allows us to change the attributes of the data
    # object and will ensure that the pipeline is properly taken care
    # of.  Directly setting the array in the VTK object will not do
    # this.
    _assign_attribute = Instance(tvtk.AssignAttribute, args=(),
                                 allow_none=False)

    ######################################################################
    # Public interface.
    ######################################################################
    def add_array(self, array, name, category='point'):
        """
        Add an array to the dataset to specified category ('point' or
        'cell').

        1D arrays are stored as scalars; Nxm arrays as scalars (m in 1, 4),
        vectors (m == 3) or tensors (m == 9).
        """
        assert len(array.shape) <= 2, "Only 2D arrays can be added."
        data = getattr(self.dataset, '%s_data' % category)
        if len(array.shape) == 2:
            assert array.shape[1] in [1, 3, 4, 9], \
                   "Only Nxm arrays where (m in [1,3,4,9]) are supported"
            mapping = {1: 'scalars', 3: 'vectors', 4: 'scalars', 9: 'tensors'}
            # resolve the kind before touching VTK so a bad width fails
            # without leaving a half-added array behind
            attr_type = mapping[array.shape[1]]
        else:
            attr_type = 'scalars'
        va = tvtk.to_tvtk(array2vtk(array))
        va.name = name
        data.add_array(va)
        arrays = getattr(self, '%s_%s' % (category, attr_type))
        arrays[name] = array

    def remove_array(self, name, category='point'):
        """Remove an array by its name and optional category (point and
        cell).  Returns the removed array.
        """
        attr_type = self._find_array(name, category)
        data = getattr(self.dataset, '%s_data' % category)
        data.remove_array(name)
        d = getattr(self, '%s_%s' % (category, attr_type))
        return d.pop(name)

    def rename_array(self, name1, name2, category='point'):
        """Rename a particular array from `name1` to `name2`.
        """
        attr_type = self._find_array(name1, category)
        data = getattr(self.dataset, '%s_data' % category)
        arr = data.get_array(name1)
        arr.name = name2
        d = getattr(self, '%s_%s' % (category, attr_type))
        d[name2] = d.pop(name1)

    def activate(self, name, category='point'):
        """Make the specified array the active one.
        """
        attr_type = self._find_array(name, category)
        self._activate_data_array(attr_type, category, name)

    def update(self):
        """Update the dataset when the arrays are changed.
        """
        self.dataset.modified()
        self._assign_attribute.update()

    ######################################################################
    # Non-public interface.
    ######################################################################
    def _dataset_changed(self, value):
        self._setup_data()
        self._assign_attribute.input = value

    def _get_output(self):
        return self._assign_attribute.output

    def _setup_data(self):
        """Updates the arrays from what is available in the input data.
        """
        pnt_attr, cell_attr = get_all_attributes(self.dataset)

        self._setup_data_arrays(cell_attr, 'cell')
        self._setup_data_arrays(pnt_attr, 'point')

    def _setup_data_arrays(self, attributes, d_type):
        """Given the dict of the attributes from the
        `get_all_attributes` function and the data type (point/cell)
        data this will setup the object and the data.
        """
        data = getattr(self.dataset, '%s_data' % d_type)
        for attr in ('scalars', 'vectors', 'tensors'):
            # Get the arrays from VTK, create numpy arrays and setup our
            # traits.
            arrays = {}
            for name in attributes[attr]:
                va = data.get_array(name)
                npa = va.to_array()
                self._share_with_vtk(va, npa)
                arrays[name] = npa
            setattr(self, '%s_%s' % (d_type, attr), arrays)

    @staticmethod
    def _share_with_vtk(va, npa):
        """Ensure in-place edits of numpy array `npa` reach VTK array `va`.

        Perturbs the first element of `npa`; if `va` does not see the
        change, `npa` is pushed into `va` with `from_array`.  The first
        element is restored afterwards.  Empty arrays are left alone.
        """
        if npa.size == 0:
            # nothing to probe, and indexing element 0 would raise
            return
        if len(npa.shape) > 1:
            old = npa[0, 0]
            npa[0][0] = old - 1
            if abs(va[0][0] - npa[0, 0]) > 1e-8:
                va.from_array(npa)
            npa[0][0] = old
        else:
            old = npa[0]
            npa[0] = old - 1
            if abs(va[0] - npa[0]) > 1e-8:
                va.from_array(npa)
            npa[0] = old

    def _activate_data_array(self, data_type, category, name):
        """Activate (or deactivate) a particular array.

        Given the nature of the data (scalars, vectors etc.) and the
        type of data (cell or points) it activates the array given by
        its name.

        Parameters:
        -----------

        data_type: one of 'scalars', 'vectors', 'tensors'
        category: one of 'cell', 'point'.
        name: string of array name to activate.
        """
        data = getattr(self.dataset, category + '_data')
        method = getattr(data, 'set_active_%s' % data_type)
        if len(name) == 0:
            # If the value is empty then we deactivate that attribute.
            method(None)
        else:
            aa = self._assign_attribute
            method(name)
            aa.assign(name, data_type.upper(), category.upper() + '_DATA')
            aa.update()

    def _find_array(self, name, category='point'):
        """Return information on which kind of attribute contains the
        specified named array in a particular category.

        Raises KeyError if no such array is known."""
        for attr_type in ('scalars', 'vectors', 'tensors'):
            d = getattr(self, '%s_%s' % (category, attr_type))
            if name in d:
                return attr_type
        raise KeyError('No %s array named %s available in dataset'
                        % (category, name))