Exemple #1
0
class PyFSContext(DirContext, Referenceable):
    """ A Python File System context.

    This context represents a directory on a local file system.  Every entry
    of the directory (except those matched by the FILTERS glob patterns in the
    environment) is bound in the context:

    - files that one of the configured object serializers can load are bound
      under their filename with the extension dropped;
    - everything else is bound under its full filename.

    Sub-directories are exposed as sub-contexts of the same class.  The
    attributes of the objects in the context are pickled into the file named
    by 'ATTRIBUTES_FILE' inside the directory.

    """

    # The name of the 'special' file in which we store object attributes.
    ATTRIBUTES_FILE = ATTRIBUTES_FILE

    # Environment property keys.
    FILTERS = FILTERS
    OBJECT_SERIALIZERS = OBJECT_SERIALIZERS

    #### 'Context' interface ##################################################

    # The naming environment in effect for this context.
    environment = Dict(ENVIRONMENT)

    # The name of the context within its own namespace.
    namespace_name = Property(Str)

    #### 'PyFSContext' interface ##############################################

    # The name of the context (the last component of the path, with any
    # extension removed - see '_path_changed').
    name = Str

    # The path name of the directory on the local file system.
    path = Str

    #### 'Referenceable' interface ############################################

    # The object's reference suitable for binding in a naming context.
    reference = Property(Instance(Reference))

    #### Private interface ####################################################

    # A mapping from bound name to the name of the corresponding file or
    # directory on the file system.
    _name_to_filename_map = Dict  #(Str, Str)

    # The attributes of every object in the context.  The attributes for the
    # context itself have the empty string as the key.
    #
    # {str name : dict attributes}
    #
    # fixme: Don't use 'Dict' here as it causes problems when pickling because
    # trait dicts have a reference back to the parent object (hence we end up
    # pickling all kinds of things that we don't need or want to!).
    _attributes = Any

    ###########################################################################
    # 'object' interface.
    ###########################################################################

    def __init__(self, **traits):
        """ Creates a new context. """

        # Base class constructor.
        super(PyFSContext, self).__init__(**traits)

        # We cache each object as it is looked up so that all accesses to a
        # serialized Python object return a reference to exactly the same one.
        self._cache = {}

        return

    ###########################################################################
    # 'PyFSContext' interface.
    ###########################################################################

    #### Properties ###########################################################

    def _get_namespace_name(self):
        """ Returns the name of the context within its own namespace.

        If the environment has a 'root' entry, the path is made relative to
        it (the leading separator is stripped too).  OS path separators are
        then replaced with '/'.

        """

        # fixme: clean this up with an initial context API!
        if 'root' in self.environment:
            root = self.environment['root']

            # '+ 1' skips the separator that follows the root directory.
            namespace_name = self.path[len(root) + 1:]

        else:
            namespace_name = self.path

        # fixme: This is a bit dodgy 'cos we actually return a name that can
        # be looked up, and not the file system name...
        namespace_name = '/'.join(namespace_name.split(os.path.sep))

        return namespace_name

    #### methods ##############################################################

    def refresh(self):
        """ Refresh the context to reflect changes in the file system.

        The name-to-filename map is rebuilt lazily on next access, the object
        cache is emptied and a 'context_changed' event is fired.

        """

        # fixme: This needs more work 'cos if we refresh a context then we
        # will load new copies of serialized Python objects!

        # This causes the initializer to run again the next time the trait is
        # accessed.
        self.reset_traits(['_name_to_filename_map'])

        # Clear out the cache.
        self._cache = {}

        # fixme: This is a bit hacky since the context in the binding may
        # not be None!
        self.context_changed = NamingEvent(
            new_binding=Binding(name=self.name, obj=self, context=None))

        return

    ###########################################################################
    # 'Referenceable' interface.
    ###########################################################################

    #### Properties ###########################################################

    def _get_reference(self):
        """ Returns a reference to this object suitable for binding.

        The reference carries the class name and a single 'pyfs_context'
        address whose content is the absolute path of the directory.

        """

        abspath = os.path.abspath(self.path)

        reference = Reference(
            class_name=self.__class__.__name__,
            addresses=[Address(type='pyfs_context', content=abspath)])

        return reference

    ###########################################################################
    # Protected 'Context' interface.
    ###########################################################################

    def _is_bound(self, name):
        """ Is a name bound in this context? """

        return name in self._name_to_filename_map

    def _lookup(self, name):
        """ Looks up a name in this context.

        Objects are loaded from disk at most once and then served from the
        cache.  A file is loaded by the first serializer whose 'can_load'
        accepts it.  Otherwise a directory becomes a sub-context and a plain
        file becomes a 'File'.

        Raises ValueError if the path is neither a file nor a directory.
        Raises KeyError if the name is not bound.

        """

        if name in self._cache:
            obj = self._cache[name]

        else:
            # Get the full path to the file.
            path = join(self.path, self._name_to_filename_map[name])

            # If the file contains a serialized Python object then load it.
            for serializer in self._get_object_serializers():
                if serializer.can_load(path):
                    try:
                        state = serializer.load(path)

                    # If the load fails then we create a generic file resource
                    # (the idea being that it might be useful to have access to
                    # the file to see what went wrong).  'Exception' rather
                    # than a bare 'except' so KeyboardInterrupt/SystemExit
                    # still propagate.
                    except Exception:
                        state = File(path)
                        logger.exception('Error loading resource at %s', path)

                    break

            # Otherwise, it must just be a file or folder.
            else:
                # Directories are contexts.
                if os.path.isdir(path):
                    state = self._context_factory(name, path)

                # Files are just files!
                elif os.path.isfile(path):
                    state = File(path)

                else:
                    raise ValueError('unrecognized file for %s' % name)

            # Get the actual object from the naming manager.
            obj = naming_manager.get_object_instance(state, name, self)

            # Update the cache.
            self._cache[name] = obj

        return obj

    def _bind(self, name, obj):
        """ Binds a name to an object in this context.

        'File' states are created on disk if they do not exist yet.  Any
        other object is written by the first serializer whose 'can_save'
        accepts it.

        Returns the state obtained from the naming manager.
        Raises ValueError if no serializer can save the object.

        """

        # Get the actual state to bind from the naming manager.
        state = naming_manager.get_state_to_bind(obj, name, self)

        # If the object is actually an abstract file then we don't have to
        # do anything.
        if isinstance(state, File):
            if not state.exists:
                state.create_file()

            filename = name

        # Otherwise we are binding an arbitrary Python object, so find a
        # serializer for it.
        else:
            for serializer in self._get_object_serializers():
                if serializer.can_save(obj):
                    # The serializer decides the final file name (e.g. it
                    # may append an extension), so use the path it returns.
                    path = serializer.save(join(self.path, name), obj)
                    filename = os.path.basename(path)
                    break

            else:
                raise ValueError('cannot serialize object %s' % name)

        # Update the name to filename map.
        self._name_to_filename_map[name] = filename

        # Update the cache.
        self._cache[name] = obj

        return state

    def _rebind(self, name, obj):
        """ Rebinds a name to an object in this context.

        This simply binds again over the existing name.

        NOTE(review): the old file is NOT removed first.  If the new object
        is saved by a serializer that produces a different filename, the old
        file is left behind on disk.

        """

        self._bind(name, obj)

        return

    def _unbind(self, name):
        """ Unbinds a name from this context.

        Deletes the underlying file, and forgets the name in the filename
        map, the cache and the attributes (saving the attributes file if
        they changed).

        """

        # Get the full path to the file.
        path = join(self.path, self._name_to_filename_map[name])

        # Remove it!
        f = File(path)
        f.delete()

        # Update the name to filename map.
        del self._name_to_filename_map[name]

        # Update the cache.
        if name in self._cache:
            del self._cache[name]

        # Remove any attributes.
        if name in self._attributes:
            del self._attributes[name]
            self._save_attributes()

        return

    def _rename(self, old_name, new_name):
        """ Renames an object in this context.

        Folders and plain files are moved on disk and the cached object's
        'path' is updated.  Serialized Python objects have their old file
        deleted and are then re-bound under the new name (see the fixme
        below for why).  Attributes are moved to the new name.

        """

        # Get the old filename.
        old_filename = self._name_to_filename_map[old_name]
        old_file = File(join(self.path, old_filename))

        # Lookup the object bound to the old name.  This has the side effect
        # of adding the object to the cache under the name 'old_name'.
        obj = self._lookup(old_name)

        # We are renaming a LOCAL context (ie. a folder)...
        if old_file.is_folder:
            # Create the new filename.
            new_filename = new_name
            new_file = File(join(self.path, new_filename))

            # Move the folder.
            old_file.move(new_file)

            # Update the 'Context' object.
            obj.path = new_file.path

            # Update the cache.
            self._cache[new_name] = obj
            del self._cache[old_name]

            # Refreshing the context makes sure that all of its contents
            # reflect the new name (i.e., sub-folders and files have the
            # correct path).
            #
            # fixme: This currently results in new copies of serialized
            # Python objects!  We need to be a bit more judicious in the
            # refresh.
            obj.refresh()

        # We are renaming a file...
        elif isinstance(obj, File):
            # Create the new filename.
            new_filename = new_name
            new_file = File(join(self.path, new_filename))

            # Move the file.
            old_file.move(new_file)

            # Update the 'File' object.
            obj.path = new_file.path

            # Update the cache.
            self._cache[new_name] = obj
            del self._cache[old_name]

        # We are renaming a serialized Python object...
        else:
            # Create the new filename (keeping the serializer's extension).
            new_filename = new_name + old_file.ext

            old_file.delete()

            # Update the cache.
            if old_name in self._cache:
                self._cache[new_name] = self._cache[old_name]
                del self._cache[old_name]

            # Force the creation of the new file.
            #
            # fixme: I'm not sure that this is really the place for this.  We
            # do it because often the 'name' of the object is actually an
            # attribute of the object itself, and hence we want the serialized
            # state to reflect the new name... Hmmm...
            self._rebind(new_name, obj)

        # Update the name to filename map.
        del self._name_to_filename_map[old_name]
        self._name_to_filename_map[new_name] = new_filename

        # Move any attributes over to the new name.
        if old_name in self._attributes:
            self._attributes[new_name] = self._attributes[old_name]
            del self._attributes[old_name]
            self._save_attributes()

        return

    def _create_subcontext(self, name):
        """ Creates a sub-context of this context.

        Makes a new directory called 'name' and returns a context for it.
        'os.mkdir' raises OSError if the directory already exists.

        """

        path = join(self.path, name)

        # Create a directory.
        os.mkdir(path)

        # Create a sub-context that represents the directory.
        sub = self._context_factory(name, path)

        # Update the name to filename map.
        self._name_to_filename_map[name] = name

        # Update the cache.
        self._cache[name] = sub

        return sub

    def _destroy_subcontext(self, name):
        """ Destroys a sub-context of this context. """

        return self._unbind(name)

    def _list_names(self):
        """ Lists the names bound in this context. """

        return self._name_to_filename_map.keys()

    # fixme: YFI this is not part of the protected 'Context' interface so
    # what is it doing here?
    def get_unique_name(self, name):
        """ Returns a name based on 'name' that is not yet bound here.

        '.py' names are made unique by inserting '_2', '_3', ... before the
        extension.  All other names are delegated to the base class.

        """

        ext = splitext(name)[1]

        # specially handle '.py' files
        if ext != '.py':
            return super(PyFSContext, self).get_unique_name(name)

        body = splitext(name)[0]
        names = self.list_names()
        i = 2
        unique = name
        while unique in names:
            unique = body + '_' + str(i) + '.py'
            i += 1

        return unique

    ###########################################################################
    # Protected 'DirContext' interface.
    ###########################################################################

    def _get_attributes(self, name):
        """ Returns a copy of the attributes of an object in this context.

        An empty attribute dict is registered (in memory only) for names that
        have none yet.

        """

        attributes = self._attributes.setdefault(name, {})

        return attributes.copy()

    def _set_attributes(self, name, attributes):
        """ Sets the attributes of an object and saves them to disk. """

        self._attributes[name] = attributes
        self._save_attributes()

        return

    ###########################################################################
    # Private interface.
    ###########################################################################

    def _get_filters(self):
        """ Returns the filters (glob patterns) for this context. """

        return self.environment.get(self.FILTERS, [])

    def _get_object_serializers(self):
        """ Returns the object serializers for this context. """

        return self.environment.get(self.OBJECT_SERIALIZERS, [])

    def _context_factory(self, name, path):
        """ Create a sub-context ('name' is currently unused). """

        return self.__class__(path=path, environment=self.environment)

    def _save_attributes(self):
        """ Saves all attributes to the attributes file. """

        path = join(self.path, self.ATTRIBUTES_FILE)

        # 'with' guarantees the file is closed even if pickling fails.
        with open(path, 'wb') as f:
            cPickle.dump(self._attributes, f, 1)

        return

    #### Trait initializers ###################################################

    def __name_to_filename_map_default(self):
        """ Initializes the '_name_to_filename_map' trait.

        Scans the directory and maps each bound name to its filename,
        skipping any entry matched by one of the filter glob patterns.

        """

        # fixme: We should have a generalized filter mechanism (instead of
        # just 'glob' patterns we should have filter objects that can be a bit
        # more flexible in how they do the filtering).
        #
        # Expand every filter pattern once up front.  The set of excluded
        # paths does not depend on which directory entry we are looking at.
        excluded = set()
        for pattern in self._get_filters():
            excluded.update(glob.glob(join(self.path, pattern)))

        serializers = self._get_object_serializers()

        name_to_filename_map = {}
        for filename in os.listdir(self.path):
            if join(self.path, filename) in excluded:
                continue

            for serializer in serializers:
                if serializer.can_load(filename):
                    # fixme: We should probably get the name from the
                    # serializer instead of assuming that we can just
                    # drop the file extension.
                    name = os.path.splitext(filename)[0]
                    break

            else:
                name = filename

            name_to_filename_map[name] = filename

        return name_to_filename_map

    def __attributes_default(self):
        """ Initializes the '_attributes' trait from the attributes file. """

        attributes_file = File(join(self.path, self.ATTRIBUTES_FILE))
        if attributes_file.is_file:
            # NOTE(review): unpickling executes arbitrary code if the
            # directory contents are untrusted.
            with open(attributes_file.path, 'rb') as f:
                attributes = cPickle.load(f)

        else:
            attributes = {}

        return attributes

    #### Trait event handlers #################################################

    def _path_changed(self):
        """ Called when the context's path has changed.

        Sets 'name' to the last path component with any extension removed.

        """

        basename = os.path.basename(self.path)

        self.name = os.path.splitext(basename)[0]

        return
Exemple #2
0
class VMPlotTab(PlotFrameTab):
    """ Plot-frame tab showing a VMPlot in an embedded matplotlib figure.

    The view exposes a log slider for the displayed width (bounded by
    'min_width'/'max_width'), a unit selector and a field selector.

    """

    pf = Instance(EnzoStaticOutput)
    figure = Instance(Figure, args=())
    field = DelegatesTo('plot_spec')
    field_list = DelegatesTo('plot_spec')
    plot = Instance(VMPlot)
    axes = Instance(Axes)
    # Width of the displayed region, expressed in 'unit'.
    disp_width = Float(1.0)
    unit = Str('unitary')
    min_width = Property(Float, depends_on=['pf', 'unit'])
    max_width = Property(Float, depends_on=['pf', 'unit'])
    unit_list = Property(depends_on='pf')
    smallest_dx = Property(depends_on='pf')

    traits_view = View(VGroup(
        HGroup(Item('figure', editor=MPLVMPlotEditor(), show_label=False)),
        HGroup(
            Item('disp_width',
                 editor=RangeEditor(format="%0.2e",
                                    low_name='min_width',
                                    high_name='max_width',
                                    mode='logslider',
                                    enter_set=True),
                 show_label=False,
                 width=400.0),
            Item('unit', editor=EnumEditor(name='unit_list')),
        ), HGroup(Item('field', editor=EnumEditor(name='field_list')), )),
                       resizable=True)

    def __init__(self, **traits):
        super(VMPlotTab, self).__init__(**traits)
        # A single equal-aspect subplot filling the figure.
        self.axes = self.figure.add_subplot(111, aspect='equal')

    def _field_changed(self, old, new):
        """ Switch the plotted field and redraw. """
        self.plot.switch_z(new)
        self._redraw()

    @cached_property
    def _get_min_width(self):
        """ Smallest selectable width: 50 of the smallest cells, in 'unit'. """
        return 50.0 * self.smallest_dx * self.pf[self.unit]

    @cached_property
    def _get_max_width(self):
        """ Largest selectable width: one 'unitary' length, in 'unit'. """
        return self.pf['unitary'] * self.pf[self.unit]

    @cached_property
    def _get_smallest_dx(self):
        return self.pf.h.get_smallest_dx()

    @cached_property
    def _get_unit_list(self):
        return self.pf.units.keys()

    def _unit_changed(self, old, new):
        """ Rescale 'disp_width' by the ratio of the unit conversion factors.

        Presumably pf[unit] is a conversion factor, so the displayed extent
        stays the same - TODO confirm.

        """
        self.disp_width = self.disp_width * self.pf[new] / self.pf[old]

    def _disp_width_changed(self, old, new):
        self.plot.set_width(new, self.unit)
        self._redraw()

    def _redraw(self):
        self.figure.canvas.draw()

    def recenter(self, event):
        """ Recenter the plot on the location of a mouse event.

        The event's 'xdata'/'ydata' are scaled by the plot's extent per pixel
        and offset by the lower limits.  The resulting position replaces the
        two in-plane components of 'center'.  The new center is pushed to the
        plot's data object (also as its 'center' field parameter) and to this
        tab.

        NOTE(review): 'axis' and 'center' are not declared in this class;
        presumably they come from PlotFrameTab - confirm.

        """
        xp, yp = event.xdata, event.ydata
        dx = abs(self.plot.xlim[0] - self.plot.xlim[1]) / self.plot.pix[0]
        dy = abs(self.plot.ylim[0] - self.plot.ylim[1]) / self.plot.pix[1]
        x = (dx * xp) + self.plot.xlim[0]
        y = (dy * yp) + self.plot.ylim[0]
        xi = lagos.x_dict[self.axis]
        yi = lagos.y_dict[self.axis]
        # Take a real copy.  On a numpy array 'center[:]' is only a view, so
        # the assignments below would mutate 'self.center' in place and the
        # final trait assignment would see no change.
        cc = self.center.copy()
        cc[xi] = x
        cc[yi] = y
        # Give each consumer its own array so they do not share memory.
        self.plot.data.center = cc.copy()
        self.plot.data.set_field_parameter('center', cc.copy())
        self.center = cc
Exemple #3
0
class ProcessObject(HasTraits):
    """
    Base class for all model component objects.

    Subclasses tag their input traits with 'process_input' metadata and
    implement 'execute'.  Whenever an input's 'modified' flag becomes True,
    this object's 'modified' flag is re-asserted, so modifications propagate
    down the tree.  The 'shape' property re-runs 'execute' only while
    'modified' is True.
    """
    name = Str

    #
    #This flag indicates if the object parameters have changed
    #
    modified = Bool(True)

    #
    #We could link each process-object to a node in an OCAF document
    #Not used yet.
    #
    label = Instance(TDF.TDF_Label)

    #
    #Parent TDF_label under which this label will be created
    #
    parent_label = Instance(TDF.TDF_Label)

    #
    #This is the output of the object. The property calls the execute method
    #to evaluate the result (which in turn calls up the tree)
    #
    shape = Property(Instance(TopoDS.TopoDS_Shape))

    #
    #Shadow trait which stores the cached shape
    #
    _shape = Instance(TopoDS.TopoDS_Shape)

    #
    #A list of all inputs, for the benefit of the TreeEditor
    #
    _inputs = List

    #
    #We hook up listeners to each input to listen to changes in their
    #modification trait. Hence, modifications propagate down the tree
    #
    @on_trait_change("+process_input")
    def on_input_change(self, obj, name, vold, vnew):
        """Re-wire listeners when a 'process_input' trait is replaced.

        Stops listening to the old input's 'modified' flag and drops it from
        '_inputs'.  Then, if the new value is not None, listens to its
        'modified' flag and records it in '_inputs'.
        """
        if vold is not None:
            vold.on_trait_change(self.on_modify, 'modified', remove=True)
            # Inputs are tracked in the '_inputs' list.
            if vold in self._inputs:
                self._inputs.remove(vold)

        # An input may be cleared to None; there is nothing to listen to then.
        if vnew is not None:
            vnew.on_trait_change(self.on_modify, 'modified')
            self._inputs.append(vnew)

    def _parent_label_changed(self, old_label, new_label):
        """Create this object's label as a new child of the parent label."""
        ts = TDF.TDF_TagSource()
        self.label = ts.NewChild(new_label)

    def on_modify(self, vnew):
        """Propagate an input's 'modified' = True to this object."""
        if vnew:
            # Toggle via False so a change notification fires even when
            # 'modified' is already True.
            self.modified = False
            self.modified = True

    def _get_shape(self):
        """Return the cached shape, re-running 'execute' if modified."""
        if self.modified:
            shape = self.execute()
            self._shape = shape
            self.modified = False
            return shape
        else:
            return self._shape

    def execute(self):
        """return a TopoDS_Shape object"""
        raise NotImplementedError

    def update_naming(self, make_shape):
        """called within the Execute method, to update the Naming
        Structure. This is key to getting Topological Naming to work"""
        raise NotImplementedError
Exemple #4
0
class ImageActor(Module):
    """ Displays image data through a tvtk ImageActor.

    The input can optionally be routed through an ImageMapToColors filter
    first.  That routing is forced whenever the input's point scalars are
    not an UnsignedCharArray.
    """

    # The tvtk actor that renders the image.
    actor = Instance(tvtk.ImageActor, allow_none=False, record=True)

    # Accept image data carrying any kind of attribute.
    input_info = PipelineInfo(datasets=['image_data'],
                              attribute_types=['any'],
                              attributes=['any'])

    # Filter that maps scalars through a lookup table, for datasets that
    # carry no colour information of their own.
    image_map_to_color = Instance(tvtk.ImageMapToColors, (),
                                  allow_none=False, record=True)

    # Whether the input is routed through 'image_map_to_color'.
    map_scalars_to_color = Bool

    # True when the input's point scalars are not unsigned chars.
    _force_map_scalars_to_color = Property(depends_on='module_manager.source')

    ########################################
    # The view of this module.

    view = View(
        Group(Item(name='actor', style='custom', resizable=True),
              show_labels=False,
              label='Actor'),
        Group(
            Group(Item('map_scalars_to_color',
                       enabled_when='not _force_map_scalars_to_color')),
            Item('image_map_to_color',
                 style='custom',
                 enabled_when='map_scalars_to_color',
                 show_label=False),
            label='Map Scalars',
        ),
        width=500,
        height=600,
        resizable=True)

    ######################################################################
    # `Module` interface
    ######################################################################
    def setup_pipeline(self):
        """Create the image actor (triggers '_actor_changed')."""
        self.actor = tvtk.ImageActor()

    @on_trait_change('map_scalars_to_color,image_map_to_color.[output_format,pass_alpha_to_output]')
    def update_pipeline(self):
        """Override this method so that it *updates* the tvtk pipeline
        when data upstream is known to have changed.

        The actor is fed either directly from the source's first output or
        from the colour-mapping filter, depending on 'map_scalars_to_color'.
        """
        manager = self.module_manager
        if manager is None:
            return
        source = manager.source

        if self._force_map_scalars_to_color:
            # Silently switch mapping on; the trait listener must not refire.
            self.set(map_scalars_to_color=True, trait_change_notify=False)

        if not self.map_scalars_to_color:
            self.actor.input = source.outputs[0]
        else:
            mapper = self.image_map_to_color
            mapper.input = source.outputs[0]
            mapper.lookup_table = manager.scalar_lut_manager.lut
            self.actor.input = mapper.output

        self.pipeline_changed = True

    def update_data(self):
        """Override this method so that it flushes the vtk pipeline if
        that is necessary.

        Only flags 'data_changed'; the component does the rest.
        """
        self.data_changed = True

    ######################################################################
    # Non-public methods.
    ######################################################################
    def _actor_changed(self, old, new):
        """Swap the registered actor and its render listener."""
        previous, current = old, new
        if previous is not None:
            self.actors.remove(previous)
            previous.on_trait_change(self.render, remove=True)
        self.actors.append(current)
        current.on_trait_change(self.render)

    def _get__force_map_scalars_to_color(self):
        """True if the source's point scalars are not unsigned chars."""
        manager = self.module_manager
        if manager is None:
            return False
        scalars = manager.source.outputs[0].point_data.scalars
        is_unsigned_char = isinstance(scalars, tvtk.UnsignedCharArray)
        return not is_unsigned_char
class InstanceContextAdapter(ContextAdapter):
    """ Context adapter for Python instances.

    Exposes the attributes of the adapted instance ('adaptee') as the
    bindings of a naming context.  For HasTraits instances the trait names
    are used; otherwise the keys of the instance '__dict__' are used.  In
    both cases the names are filtered through 'include'/'exclude'.

    """

    #### 'Context' interface ##################################################

    # The name of the context within its own namespace.
    namespace_name = Property(Str)

    #### 'InstanceContextAdapter' interface ###################################

    # By default every public attribute of an instance is exposed. Use the
    # following traits to either include or exclude attributes as appropriate.
    #
    # Regular expressions that describe the names of attributes to include.
    # When non-empty, 'exclude' is ignored.
    include = List(Str)

    # Regular expressions that describe the names of attributes to exclude.  By
    # default we exclude 'protected' and 'private' attributes and any
    # attributes that are artifacts of the traits mechanism.
    exclude = List(Str, ['_', 'trait_'])

    ###########################################################################
    # 'Context' interface.
    ###########################################################################

    #### Properties ###########################################################

    def _get_namespace_name(self):
        """ Returns the name of the context within its own namespace.

        This is the parent context's namespace name joined with '/' to the
        first name under which the adaptee is found in the parent context.
        NOTE(review): if the adaptee is not found, 'names[0]' presumably
        raises IndexError.

        """

        base = self.context.namespace_name
        if len(base) > 0:
            base += '/'

        names = self.context.search(self.adaptee)

        return base + names[0]

    ###########################################################################
    # Protected 'Context' interface.
    ###########################################################################

    def _is_bound(self, name):
        """ Is a name bound in this context? """

        return name in self._list_names()

    def _lookup(self, name):
        """ Looks up a name in this context. """

        obj = getattr(self.adaptee, name)

        return naming_manager.get_object_instance(obj, name, self)

    def _lookup_binding(self, name):
        """ Looks up the binding for a name in this context. """

        return Binding(name=name, obj=self._lookup(name), context=self)

    def _bind(self, name, obj):
        """ Binds a name to an object in this context.

        The state returned by the naming manager (not 'obj' itself) is stored
        as an attribute of the adaptee.

        """

        state = naming_manager.get_state_to_bind(obj, name, self)
        setattr(self.adaptee, name, state)

        return

    def _rebind(self, name, obj):
        """ Rebinds a name to an object in this context. """

        self._bind(name, obj)

        return

    def _unbind(self, name):
        """ Unbinds a name from this context. """

        delattr(self.adaptee, name)

        return

    def _rename(self, old_name, new_name):
        """ Renames an object in this context.

        The value stored on the adaptee is moved as-is.  It is not passed
        through the naming manager, because '_lookup' would return the
        object *instance* rather than the stored state that '_bind' writes.

        """

        # Bind the new name to the raw stored value.
        setattr(self.adaptee, new_name, getattr(self.adaptee, old_name))

        # Unbind the old one.
        delattr(self.adaptee, old_name)

        return

    def _create_subcontext(self, name):
        """ Creates a sub-context of this context (not supported). """

        raise OperationNotSupportedError()

    def _destroy_subcontext(self, name):
        """ Destroys a sub-context of this context (not supported). """

        raise OperationNotSupportedError()

    def _list_bindings(self):
        """ Lists the bindings in this context. """

        bindings = []
        for name in self._list_names():
            try:
                obj = self._lookup(name)
                bindings.append(Binding(name=name, obj=obj, context=self))

            # We get attribute errors when we try to look up Event traits (they
            # are write-only).
            except AttributeError:
                pass

        return bindings

    def _list_names(self):
        """ Lists the names bound in this context. """

        return self._get_public_attribute_names(self.adaptee)

    ###########################################################################
    # Private interface.
    ###########################################################################

    def _get_public_attribute_names(self, obj):
        """ Returns the names of an object's public attributes. """

        if isinstance(obj, HasTraits):
            names = obj.trait_names()

        elif hasattr(obj, '__dict__'):
            # Use the object we were given, not the adaptee.
            names = obj.__dict__.keys()

        else:
            names = []

        return [name for name in names if self._is_exposed(name)]

    def _is_exposed(self, name):
        """ Returns True iff a name should be exposed.

        A non-empty 'include' list takes precedence.  Otherwise a non-empty
        'exclude' list is applied.  With both empty, everything is exposed.

        """

        if len(self.include) > 0:
            is_exposed = self._matches(self.include, name)

        elif len(self.exclude) > 0:
            is_exposed = not self._matches(self.exclude, name)

        else:
            is_exposed = True

        return is_exposed

    def _matches(self, expressions, name):
        """ Returns True iff a name matches any of a list of expressions.

        Matching uses 're.match', i.e. it is anchored at the start of 'name'.

        """

        return any(re.match(expression, name) is not None
                   for expression in expressions)
Exemple #6
0
class DataSetBrowser(HasTraits):
    """
    A class that allows browsing of a DataSet object with sliders
    to navigate through plates, images within plates, and objects
    within images.

    The slider indices (``plate_index``, ``image_index``, ``object_index``)
    are 1-based; they are converted to 0-based indices when the underlying
    containers are accessed.  When an index is out of range the matching
    ``current_*`` trait is set to None and the plots are left unchanged.
    """

    view = View(VGroup(
        HGroup(
            Item('image_plots',
                 editor=ComponentEditor(size=(50, 50)),
                 show_label=False), ),
        HGroup(
            Item('plots',
                 editor=ComponentEditor(size=(250, 300)),
                 show_label=False), ),
        Group(
            Item('object_index',
                 editor=RangeEditor(low=1,
                                    high_name='num_objects',
                                    mode='slider')),
            Item('image_index',
                 editor=RangeEditor(low=1,
                                    high_name='num_images',
                                    mode='slider')),
            Item('plate_index',
                 editor=RangeEditor(low=1,
                                    high_name='num_plates',
                                    mode='slider')),
        ),
        HGroup(
            Item('num_internal_knots',
                 label='Number of internal spline knots'),
            Item('smoothing', label='Amount of smoothing applied'))),
                height=700,
                width=800,
                resizable=True)

    # Chaco plots.  sil_plot and rotated_plot are shown side by side in
    # image_plots; the width curve and its Laplacian are laid out in plots.
    gfp_plot = Instance(Plot)
    sil_plot = Instance(Plot)
    image_plots = Instance(HPlotContainer)
    rotated_plot = Instance(Plot)
    plots = Instance(GridPlotContainer)
    #legends = Instance(VPlotContainer)
    # DataSet being viewed
    dataset = Instance(DataSet)

    # Plate object currently being examined
    current_plate = Instance(Plate)

    # ImageSilhouette object currently being examined
    current_image = Instance(ImageSilhouette)

    # ObjectSilhouette object currently being examined
    current_object = Instance(ObjectSilhouette)

    # Index traits that control the selected plate/image/object (1-based)
    plate_index = Int(1)
    image_index = Int(1)
    object_index = Int(1)

    # Number of plates, images, and objects in the current context
    num_plates = Property(Int, depends_on='dataset')
    num_images = Property(Int, depends_on='current_plate')
    num_objects = Property(Int, depends_on='current_image')
    # Spline / filter parameters for the width-curve plot
    num_internal_knots = Range(1, 20, 3)
    smoothing = Range(0.0, 2.0, 0)

    def __init__(self, *args, **kwargs):
        """Construct a DataSetBrowser from the specified DataSet object."""
        super(DataSetBrowser, self).__init__(*args, **kwargs)
        self.current_plate = self.dataset[self.plate_index - 1]
        self.current_image = self.current_plate[self.image_index - 1]
        self.current_object = self.current_image[self.object_index - 1]
        self.sil_plot = Plot()
        self._object_index_changed()

    ######################### Private interface ##########################

    def _plate_index_changed(self):
        """Handle the plate index changing."""
        try:
            self.current_plate = self.dataset[self.plate_index - 1]
        except IndexError:
            self.current_plate = None
        self.image_index = 1
        self._image_index_changed()

    def _image_index_changed(self):
        """Handle the image index slider changing."""
        plate = self.current_plate
        if plate is None:
            # No valid plate selected; indexing None would raise TypeError.
            self.current_image = None
        else:
            try:
                self.current_image = plate[self.image_index - 1]
            except IndexError:
                self.current_image = None
        self.object_index = 1
        self._object_index_changed()

    def _object_index_changed(self):
        """Handle object index slider changing.

        Selects the object and redraws the mask plots and the spline plot.
        """
        image = self.current_image
        if image is None:
            self.current_object = None
            return

        # Only the lookup is guarded: IndexErrors raised while plotting are
        # real errors and must not be mistaken for "no such object".
        try:
            self.current_object = image[self.object_index - 1]
        except IndexError:
            self.current_object = None
            return

        # Display
        sil = self.current_object.image
        self._update_img_plot('sil_plot', sil, 'Extracted mask')

        # .T to get major axis horizontal
        rotated = self.current_object.aligned_version.image.T
        self._update_img_plot('rotated_plot', rotated, 'Aligned mask')

        self.image_plots = HPlotContainer(self.sil_plot,
                                          self.rotated_plot,
                                          valign="top",
                                          bgcolor="transparent")

        self._update_spline_plot()

    def _get_num_plates(self):
        """Return the number of plates in the currently viewed dataset."""
        if self.dataset is None:
            return 0
        return len(self.dataset)

    def _get_num_images(self):
        """Return the number of images in the currently viewed plate."""
        if self.current_plate is None:
            return 0
        return len(self.current_plate)

    def _get_num_objects(self):
        """Return the number of objects in the currently viewed image."""
        if self.current_image is None:
            return 0
        return len(self.current_image)

    def _update_img_plot(self, plot_name, image, title):
        """Replace the plot stored in trait *plot_name* with an image plot.

        *image* is indexed as ``image.shape[0]`` rows by ``image.shape[1]``
        columns.
        """
        plotdata = ArrayPlotData(imagedata=image)
        xbounds = (0, image.shape[1] - 1)
        ybounds = (0, image.shape[0] - 1)

        plot = Plot(plotdata)
        # Guard against single-row images, whose upper y bound is 0.
        plot.aspect_ratio = float(xbounds[1]) / float(max(ybounds[1], 1))
        plot.img_plot("imagedata",
                      colormap=bone,
                      xbounds=xbounds,
                      ybounds=ybounds)
        plot.title = title

        setattr(self, plot_name, plot)
        plot.request_redraw()

    def _update_spline_plot(self):
        """Update the width-curve and Laplacian plots for the current object.

        The first call builds the plot containers.  Later calls only replace
        the data of the existing width and Laplacian renderers.  Does nothing
        when no object is selected.
        """
        if self.current_object is None:
            # e.g. knots/smoothing changed while an empty image is selected.
            return

        knots = np.mgrid[0:1:((self.num_internal_knots + 2) * 1j)][1:-1]
        medial_repr = self.current_object.aligned_version.medial_repr
        dependent_variable = np.mgrid[0:1:(medial_repr.length * 1j)]
        laplacian = ndimage.gaussian_laplace(medial_repr.width_curve,
                                             self.smoothing,
                                             mode='constant',
                                             cval=np.nan)
        m_spline = LSQUnivariateSpline(dependent_variable,
                                       medial_repr.medial_axis, knots)
        w_spline = LSQUnivariateSpline(dependent_variable,
                                       medial_repr.width_curve, knots)
        # sample at double the frequency
        spl_dep_var = np.mgrid[0:1:(medial_repr.length * 2j)]
        plots = self.plots
        if plots is None:
            # Render the plot for the first time.
            plotdata = ArrayPlotData(
                medial_x=dependent_variable,
                medial_y=medial_repr.medial_axis,
                width_x=dependent_variable,
                width_y=medial_repr.width_curve,
                medial_spline_x=spl_dep_var,
                medial_spline_y=m_spline(spl_dep_var),
                width_spline_x=spl_dep_var,
                width_spline_y=w_spline(spl_dep_var),
                laplacian_y=laplacian,
            )
            plot = Plot(plotdata)

            # Width data
            self._width_data_renderer, = plot.plot(
                ("width_x", "width_y"),
                type="line",
                color="blue",
                name="Original width curve data")

            filterdata = ArrayPlotData(x=dependent_variable,
                                       laplacian=laplacian)
            filterplot = Plot(filterdata)
            self._laplacian_renderer, = filterplot.plot(
                ("x", "laplacian"),
                type="line",
                color="black",
                name="Laplacian-of-Gaussian")

            # Titles for plot & axes
            plot.title = "Width curves"
            plot.x_axis.title = "Normalized position on medial axis"
            plot.y_axis.title = "Fraction of medial axis width"

            # Detach the legends from their plots so they can be placed in
            # their own cells of the grid container below.
            legend = plot.legend
            plot.legend = None
            legend.set(component=None,
                       visible=True,
                       resizable="",
                       auto_size=True,
                       bounds=[250, 70],
                       padding_top=plot.padding_top)

            filterlegend = filterplot.legend
            filterplot.legend = None
            filterlegend.set(component=None,
                             visible=True,
                             resizable="",
                             auto_size=True,
                             bounds=[250, 50],
                             padding_top=filterplot.padding_top)

            self.plots = GridPlotContainer(plot,
                                           legend,
                                           filterplot,
                                           filterlegend,
                                           shape=(2, 2),
                                           valign="top",
                                           bgcolor="transparent")

        else:

            # Update the real width curve
            self._width_data_renderer.index.set_data(dependent_variable)
            self._width_data_renderer.value.set_data(medial_repr.width_curve)

            # Render the Laplacian
            self._laplacian_renderer.index.set_data(dependent_variable)
            self._laplacian_renderer.value.set_data(laplacian)

    def _num_internal_knots_changed(self):
        """Hook to update the plot when we change the number of knots."""
        self._update_spline_plot()

    def _smoothing_changed(self):
        """Hook to update the plot when we change the smoothing parameter."""
        self._update_spline_plot()
class TVTKDocument(HasTraits):
    """Browser for the documentation of TVTK classes.

    Shows a tree of TVTK classes next to the documentation of the selected
    class, plus a full-text search over all class docs.  Searches shorter
    than three characters are ignored.  A search is case-insensitive when
    ``search.islower()`` is true and case-sensitive otherwise.  Lines of the
    shown document that contain the search text are marked.
    """
    tree_editor = Instance(TreeEditor)
    nodes = Any
    search = Str
    search_result = List(Instance(TVTKClass))
    search_result_str = Property(depends_on="search_result")
    search_result_index = Int
    object_class = Instance(TVTKClass)
    tree_selected = Instance(TVTKClass)
    mark_lines = List(Int)
    current_line = Int
    current_document = Str
    show_tree = Bool(True)

    def default_traits_view(self):
        # The UI labels are Chinese: u"搜索" is "Search" and the window title
        # is "TVTK documentation browser".
        view = View(
            HSplit(VSplit(
                Item("object_class",
                     editor=self.tree_editor,
                     show_label=False,
                     visible_when="object.show_tree"),
                Group(Item("search", label=u"搜索"),
                      Item("search_result_str",
                           show_label=False,
                           editor=ListStrEditor(
                               editable=False,
                               selected_index="search_result_index")),
                      label="Search"),
            ),
                   Item("current_document",
                        style="custom",
                        show_label=False,
                        editor=CodeEditor(lexer="null",
                                          search="top",
                                          line="current_line",
                                          mark_lines="mark_lines",
                                          mark_color=0xff7777)),
                   id="tvtkdoc.hsplit"),
            width=700,
            height=500,
            resizable=True,
            title=u"TVTK文档浏览器",
            id="tvtkdoc",
            handler=TVTKDocumentHandler(),
        )
        return view

    @cached_property
    def _get_search_result_str(self):
        return [obj.name for obj in self.search_result]

    def _fold_case(self, text):
        """Return *text* lower-cased when the current search ignores case.

        The search ignores case exactly when ``self.search.islower()`` is
        true, so the same rule applies to searching and line marking.
        """
        if self.search.islower():
            return text.lower()
        return text

    def _search_changed(self):
        """Collect, sorted by name, every class whose doc contains the search."""
        if len(self.search) < 3:
            return
        result = [cls for cls in TVTKClass.Classes.values()
                  if self.search in self._fold_case(cls.doc)]
        result.sort(key=lambda obj: obj.name)
        self.search_result = result

    def _search_result_index_changed(self):
        """Select in the tree the class picked in the search-result list."""
        index = self.search_result_index
        # A negative index means "no selection".  The index may also be stale
        # once search_result has been replaced by a shorter list.
        if 0 <= index < len(self.search_result):
            self.tree_selected = self.search_result[index]

    def _object_class_default(self):
        obj = TVTKClass(name="Object")
        self.tree_selected = obj.children[0]
        return obj

    def _tree_editor_default(self):
        return TreeEditor(editable=False,
                          hide_root=True,
                          nodes=self.nodes,
                          selected="tree_selected")

    def _nodes_default(self):
        nodes = [
            ObjectTreeNode(
                node_for=[TVTKClass],
                children="children",
                label="name",
                auto_open=True,
                copy=True,
                delete=True,
                rename=True,
            )
        ]
        return nodes

    def _tree_selected_changed(self):
        """Show the selected class's doc and mark lines matching the search."""
        selected = self.tree_selected
        if selected is None:
            # An Instance trait may be reset to None; show an empty document
            # instead of failing on None.doc.
            self.current_document = ""
            self.mark_lines = []
            self.current_line = 1
            return

        self.current_document = selected.doc
        if len(self.search) < 3:
            self.mark_lines = []
            self.current_line = 1
            return

        lines = self._fold_case(selected.doc).split("\n")
        # Editor line numbers are 1-based.
        result = [i + 1 for i, line in enumerate(lines) if self.search in line]
        self.mark_lines = result
        self.current_line = result[0] if result else 1
Exemple #8
0
class Pulsed(ManagedJob, GetSetItemsMixin):
    """Defines a pulsed measurement.

    While the pulse generator loops over ``sequence``, gated counts are read
    from the nidaq every ``readout_interval`` seconds and copied into
    ``count_data``.  Reading stops when ``measurement_points`` samples are
    acquired or a stop is requested.
    """
    keep_data = Bool(False)  # helper variable to decide whether to keep existing data

    resubmit_button = Button(
        label='resubmit',
        desc=
        'Submits the measurement to the job manager. Tries to keep previously acquired data. Behaves like a normal submit if sequence or time bins have changed since previous run.'
    )

    sequence = Instance(list, factory=list)

    record_length = Float(value=0,
                          desc='length of acquisition record [ms]',
                          label='record length [ms] ',
                          mode='text')

    count_data = Array(value=np.zeros(2))

    # Measured with time.time(), hence seconds.
    run_time = Float(value=0.0, label='run time [s]', format_str='%.f')
    stop_time = Range(
        low=1.,
        value=np.inf,
        desc='Time after which the experiment stops by itself [s]',
        label='Stop time [s]',
        mode='text',
        auto_set=False,
        enter_set=True)

    tau_begin = Range(low=0.,
                      high=1e5,
                      value=300.,
                      desc='tau begin [ns]',
                      label='repetition',
                      mode='text',
                      auto_set=False,
                      enter_set=True)
    tau_end = Range(low=1.,
                    high=1e5,
                    value=4000.,
                    desc='tau end [ns]',
                    label='N repetition',
                    mode='text',
                    auto_set=False,
                    enter_set=True)
    tau_delta = Range(low=1.,
                      high=1e5,
                      value=50.,
                      desc='delta tau [ns]',
                      label='delta',
                      mode='text',
                      auto_set=False,
                      enter_set=True)

    # 2-element placeholder until load()/apply_parameters() builds the grid
    tau = Array(value=np.array((0., 1.)))
    sequence_points = Int(value=2, label='number of points', mode='text')

    laser_SST = Range(low=1.,
                      high=5e6,
                      value=200.,
                      desc='laser for SST [ns]',
                      label='laser_SST[ns]',
                      mode='text',
                      auto_set=False,
                      enter_set=True)
    wait_SST = Range(low=1.,
                     high=5e6,
                     value=1000.,
                     desc='wait for SST[ns]',
                     label='wait_SST [ns]',
                     mode='text',
                     auto_set=False,
                     enter_set=True)
    N_shot = Range(low=1,
                   high=20e5,
                   value=2e3,
                   desc='number of shots in SST',
                   label='N_shot',
                   mode='text',
                   auto_set=False,
                   enter_set=True)

    laser = Range(low=1.,
                  high=5e4,
                  value=3000,
                  desc='laser [ns]',
                  label='laser [ns]',
                  mode='text',
                  auto_set=False,
                  enter_set=True)
    wait = Range(low=1.,
                 high=5e4,
                 value=5000.,
                 desc='wait [ns]',
                 label='wait [ns]',
                 mode='text',
                 auto_set=False,
                 enter_set=True)

    freq_center = Range(low=1,
                        high=20e9,
                        value=2.71e9,
                        desc='frequency [Hz]',
                        label='MW freq[Hz]',
                        editor=TextEditor(auto_set=False,
                                          enter_set=True,
                                          evaluate=float,
                                          format_str='%.4e'))
    power = Range(low=-100.,
                  high=25.,
                  value=-26,
                  desc='power [dBm]',
                  label='power[dBm]',
                  editor=TextEditor(auto_set=False,
                                    enter_set=True,
                                    evaluate=float))
    freq = Range(low=1,
                 high=20e9,
                 value=2.71e9,
                 desc='frequency [Hz]',
                 label='freq [Hz]',
                 editor=TextEditor(auto_set=False,
                                   enter_set=True,
                                   evaluate=float,
                                   format_str='%.4e'))
    pi = Range(low=0.,
               high=5e4,
               value=2e3,
               desc='pi pulse length',
               label='pi [ns]',
               mode='text',
               auto_set=False,
               enter_set=True)

    amp = Range(low=0.,
                high=1.0,
                value=1.0,
                desc='Normalized amplitude of waveform',
                label='Amp',
                mode='text',
                auto_set=False,
                enter_set=True)
    vpp = Range(low=0.,
                high=4.5,
                value=0.6,
                desc='Amplitude of AWG [Vpp]',
                label='Vpp',
                mode='text',
                auto_set=False,
                enter_set=True)

    sweeps = Range(low=1.,
                   high=1e4,
                   value=1e2,
                   desc='number of sweeps',
                   label='sweeps',
                   mode='text',
                   auto_set=False,
                   enter_set=True)
    expected_duration = Property(
        trait=Float,
        depends_on='sweeps,sequence',
        desc='expected duration of the measurement [s]',
        label='expected duration [s]')
    elapsed_sweeps = Float(value=0,
                           desc='Elapsed Sweeps ',
                           label='Elapsed Sweeps ',
                           mode='text')
    # Measured with time.time(), hence seconds.
    elapsed_time = Float(value=0,
                         desc='Elapsed Time [s]',
                         label='Elapsed Time [s]',
                         mode='text')
    progress = Int(value=0,
                   desc='Progress [%]',
                   label='Progress [%]',
                   mode='text')

    load_button = Button(desc='compile and upload waveforms to AWG',
                         label='load')
    reload = True

    readout_interval = Float(
        1,
        label='Data readout interval [s]',
        desc='How often data read is requested from nidaq')
    samples_per_read = Int(
        200,
        label='# data points per read',
        desc=
        'Number of data points requested from nidaq per read. Nidaq will automatically wait for the data points to be aquired.'
    )

    def submit(self):
        """Submit the job to the JobManager, discarding previous data."""
        self.keep_data = False
        ManagedJob.submit(self)

    def resubmit(self):
        """Submit the job to the JobManager, keeping previous data if possible."""
        self.keep_data = True
        ManagedJob.submit(self)

    def _resubmit_button_fired(self):
        """React to the resubmit button. Resubmit the Job."""
        self.resubmit()

    def generate_sequence(self):
        """Return the pulse sequence; subclasses override this."""
        return []

    def prepare_awg(self):
        """ override this """
        AWG.reset()

    def _load_button_changed(self):
        self.load()

    def load(self):
        """Recompute record_length and tau, then prepare the AWG."""
        self.reload = True
        # update record_length, in ms
        self.record_length = self.N_shot * (self.pi + self.laser_SST +
                                            self.wait_SST) * 1e-6
        #make sure tau is updated
        self.tau = np.arange(self.tau_begin, self.tau_end, self.tau_delta)
        self.prepare_awg()
        self.reload = False

    @cached_property
    def _get_expected_duration(self):
        # step[1] is summed as the step duration; the 1e-9 factor converts
        # it from ns to s.
        sequence_length = sum(step[1] for step in self.sequence)
        return self.sweeps * sequence_length * 1e-9

    def _get_sequence_points(self):
        return len(self.tau)

    def apply_parameters(self):
        """Apply the current parameters and decide whether to keep previous data.

        If the load button was not used, ``tau`` still holds its 2-element
        placeholder and is generated from tau_begin/tau_end/tau_delta here.
        """
        if self.tau.shape[0] == 2:
            self.tau = np.arange(self.tau_begin, self.tau_end, self.tau_delta)

        self.sequence_points = self._get_sequence_points()
        self.measurement_points = self.sequence_points * int(self.sweeps)
        sequence = self.generate_sequence()

        if self.keep_data and sequence == self.sequence:
            # The sequence is the same as in the previous run, so keep the
            # existing data and continue counting from where we stopped.
            self.previous_sweeps = self.elapsed_sweeps
            self.previous_elapsed_time = self.elapsed_time
        else:
            self.count_data = np.zeros(self.measurement_points)
            self.old_count_data = np.zeros(self.measurement_points)
            self.previous_sweeps = 0
            self.previous_elapsed_time = 0.0
            self.run_time = 0.0

        # When the job manager stops and restarts the job, data should be
        # kept. Only a new submission should clear data.
        self.keep_data = True

        self.sequence = sequence

    def _run(self):
        """Acquire data.

        Configures pulse generator, AWG and microwave, then reads gated counts
        from the nidaq until all measurement points are acquired or a stop is
        requested.  The nidaq counting task is always stopped on exit.
        """

        try:  # try to run the acquisition from start_up to shut_down
            self.state = 'run'
            self.apply_parameters()

            PG.High([])

            self.prepare_awg()
            MW.setFrequency(self.freq_center)
            MW.setPower(self.power)

            AWG.run()
            time.sleep(4.0)
            PG.Sequence(self.sequence, loop=True)

            # initialize and start nidaq gated counting task, returns 0 if
            # successful
            if CS.configure() != 0:
                print('error in nidaq')
                # Without this the job would keep reporting 'run'.
                self.state = 'error'
                return

            start_time = time.time()

            aquired_data = np.empty(0)  # new data will be appended to this array

            while True:

                self.thread.stop_request.wait(self.readout_interval)
                if self.thread.stop_request.isSet():
                    logging.getLogger().debug('Caught stop signal. Exiting.')
                    break

                points_left = self.measurement_points - len(aquired_data)

                # elapsed_time already includes earlier runs, so run_time is
                # assigned from it rather than accumulated (accumulating a
                # cumulative value made it grow quadratically).
                self.elapsed_time = (self.previous_elapsed_time +
                                     time.time() - start_time)
                self.run_time = self.elapsed_time

                # do not attempt to read more data than necessary
                new_data = CS.read_gated_counts(
                    SampleLength=min(self.samples_per_read, points_left))

                aquired_data = np.append(
                    aquired_data, new_data[:min(len(new_data), points_left)])

                # length of trace may not change due to plot, so just copy
                # the acquired data into the trace
                self.count_data[:len(aquired_data)] = aquired_data[:]

                # Likewise elapsed_sweeps is recomputed, not accumulated.
                sweeps = len(aquired_data) / self.sequence_points
                self.elapsed_sweeps = self.previous_sweeps + sweeps
                self.progress = int(100 * len(aquired_data) /
                                    self.measurement_points)

                if self.progress > 99.9:
                    break

            MW.Off()
            PG.High(['laser', 'mw'])
            AWG.stop()

            if self.elapsed_sweeps < self.sweeps:
                self.state = 'idle'
            else:
                self.state = 'done'

        except Exception:  # if anything fails, log the exception and set the state
            logging.getLogger().exception(
                'Something went wrong in pulsed loop.')
            self.state = 'error'

        finally:
            CS.stop_gated_counting()  # stop nidaq task to free counters

    get_set_items = [
        '__doc__', 'record_length', 'laser', 'wait', 'sequence', 'count_data',
        'run_time', 'tau_begin', 'tau_end', 'tau_delta', 'tau', 'freq_center',
        'power', 'laser_SST', 'wait_SST', 'amp', 'vpp', 'pi', 'freq', 'N_shot',
        'readout_interval', 'samples_per_read'
    ]

    traits_view = View(
        VGroup(
            HGroup(
                Item('load_button', show_label=False),
                Item('submit_button', show_label=False),
                Item('remove_button', show_label=False),
                Item('resubmit_button', show_label=False),
                Item('priority'),
            ),
            HGroup(
                Item('freq', width=-70),
                Item('freq_center', width=-70),
                Item('amp', width=-30),
                Item('vpp', width=-30),
                Item('power', width=-40),
                Item('pi', width=-70),
            ),
            HGroup(
                Item('laser', width=-60),
                Item('wait', width=-60),
                Item('laser_SST', width=-50),
                Item('wait_SST', width=-50),
            ),
            HGroup(
                Item('samples_per_read', width=-50),
                Item('N_shot', width=-50),
                Item('record_length', style='readonly'),
            ),
            HGroup(
                Item('tau_begin', width=30),
                Item('tau_end', width=30),
                Item('tau_delta', width=30),
            ),
            HGroup(
                Item('state', style='readonly'),
                Item('run_time', style='readonly', format_str='%.f',
                     width=-60),
                Item('sweeps',
                     editor=TextEditor(auto_set=False,
                                       enter_set=True,
                                       evaluate=float,
                                       format_func=lambda x: '%.3e' % x),
                     width=-50),
                Item('expected_duration',
                     style='readonly',
                     editor=TextEditor(evaluate=float,
                                       format_func=lambda x: '%.2f' % x),
                     width=30),
                Item('progress', style='readonly'),
                Item('elapsed_time',
                     style='readonly',
                     editor=TextEditor(auto_set=False,
                                       enter_set=True,
                                       evaluate=float,
                                       format_func=lambda x: ' %.f' % x),
                     width=-50),
            ),
        ),
        title='Pulsed_SST Measurement',
    )
Exemple #9
0
class DatasetManager(HasTraits):
    """Expose the point/cell data arrays of a TVTK dataset as numpy arrays.

    Arrays are grouped by category ('point' or 'cell') and by attribute type
    ('scalars', 'vectors', 'tensors') in the ``<category>_<type>`` dict
    traits.  ``output`` is the dataset as seen through an ``AssignAttribute``
    filter, so the active arrays can be changed without breaking the
    pipeline.
    """

    # The TVTK dataset we manage.
    dataset = Instance(tvtk.DataSet)

    # Our output, this is the dataset modified by us with different
    # active arrays.
    output = Property(Instance(tvtk.DataSet))

    # The point scalars for the dataset.  You may manipulate the arrays
    # in-place.  However adding new keys in this dict will not set the
    # data in the `dataset` for that you must explicitly call
    # `add_array`.
    point_scalars = Dict(Str, Array)
    # Point vectors.
    point_vectors = Dict(Str, Array)
    # Point tensors.
    point_tensors = Dict(Str, Array)

    # The cell scalars for the dataset.
    cell_scalars = Dict(Str, Array)
    cell_vectors = Dict(Str, Array)
    cell_tensors = Dict(Str, Array)

    # This filter allows us to change the attributes of the data
    # object and will ensure that the pipeline is properly taken care
    # of.  Directly setting the array in the VTK object will not do
    # this.
    _assign_attribute = Instance(tvtk.AssignAttribute, args=(),
                                 allow_none=False)


    ######################################################################
    # Public interface.
    ######################################################################
    def add_array(self, array, name, category='point'):
        """
        Add an array to the dataset to specified category ('point' or
        'cell').

        Arrays with fewer than two dimensions are registered as scalars.
        2-D arrays must be ``N x m`` with ``m`` in (1, 3, 4, 9) and are
        registered as scalars (m = 1 or 4), vectors (m = 3) or tensors
        (m = 9).  Invalid shapes raise AssertionError.
        """
        assert len(array.shape) <= 2, "Only 2D arrays can be added."
        data = getattr(self.dataset, '%s_data' % category)
        if len(array.shape) == 2:
            assert array.shape[1] in [1, 3, 4, 9], \
                    "Only Nxm arrays where (m in [1,3,4,9]) are supported"
            mapping = {1: 'scalars', 3: 'vectors', 4: 'scalars',
                       9: 'tensors'}
            attr_type = mapping[array.shape[1]]
        else:
            attr_type = 'scalars'

        va = tvtk.to_tvtk(array2vtk(array))
        va.name = name
        data.add_array(va)
        arrays = getattr(self, '%s_%s' % (category, attr_type))
        arrays[name] = array

    def remove_array(self, name, category='point'):
        """Remove an array by its name and optional category (point and
        cell).  Returns the removed array.

        Raises KeyError if no such array is known.
        """
        attr_type = self._find_array(name, category)
        data = getattr(self.dataset, '%s_data' % category)
        data.remove_array(name)
        arrays = getattr(self, '%s_%s' % (category, attr_type))
        return arrays.pop(name)

    def rename_array(self, name1, name2, category='point'):
        """Rename a particular array from `name1` to `name2`.

        Raises KeyError if `name1` is not known.
        """
        attr_type = self._find_array(name1, category)
        data = getattr(self.dataset, '%s_data' % category)
        vtk_array = data.get_array(name1)
        vtk_array.name = name2
        arrays = getattr(self, '%s_%s' % (category, attr_type))
        arrays[name2] = arrays.pop(name1)

    def activate(self, name, category='point'):
        """Make the specified array the active one.
        """
        attr_type = self._find_array(name, category)
        self._activate_data_array(attr_type, category, name)

    def update(self):
        """Update the dataset when the arrays are changed.
        """
        self.dataset.modified()
        self._assign_attribute.update()

    ######################################################################
    # Non-public interface.
    ######################################################################
    def _dataset_changed(self, value):
        self._setup_data()
        self._assign_attribute.input = value

    def _get_output(self):
        return self._assign_attribute.output

    def _setup_data(self):
        """Updates the arrays from what is available in the input data.
        """
        point_attrs, cell_attrs = get_all_attributes(self.dataset)

        self._setup_data_arrays(cell_attrs, 'cell')
        self._setup_data_arrays(point_attrs, 'point')

    def _setup_data_arrays(self, attributes, d_type):
        """Given the dict of the attributes from the
        `get_all_attributes` function and the data type (point/cell)
        data this will setup the object and the data.
        """
        data = getattr(self.dataset, '%s_data' % d_type)
        for attr in ('scalars', 'vectors', 'tensors'):
            # Get the arrays from VTK, create numpy arrays and setup our
            # traits.
            arrays = {}
            for name in attributes[attr]:
                va = data.get_array(name)
                npa = va.to_array()
                self._share_numpy_with_vtk(va, npa)
                arrays[name] = npa

            setattr(self, '%s_%s' % (d_type, attr), arrays)

    @staticmethod
    def _share_numpy_with_vtk(va, npa):
        """Make sure in-place edits of `npa` are reflected in VTK array `va`.

        The first element of `npa` is perturbed temporarily.  If `va` does
        not see the change, the two do not share memory and `va` is reset
        from `npa`.
        """
        if npa.size == 0:
            # Nothing to probe (and npa[0] would raise IndexError).
            return
        index = (0, 0) if npa.ndim > 1 else (0,)
        old = npa[index]
        npa[index] = old - 1
        if npa.ndim > 1:
            vtk_value = va[0][0]
        else:
            vtk_value = va[0]
        diverged = abs(vtk_value - npa[index]) > 1e-8
        # Restore the probe *before* copying, so a copying from_array never
        # stores the perturbed value.
        npa[index] = old
        if diverged:
            va.from_array(npa)

    def _activate_data_array(self, data_type, category, name):
        """Activate (or deactivate) a particular array.

        Given the nature of the data (scalars, vectors etc.) and the
        type of data (cell or points) it activates the array given by
        its name.

        Parameters:
        -----------

        data_type: one of 'scalars', 'vectors', 'tensors'
        category: one of 'cell', 'point'.
        name: string of array name to activate.
        """
        data = getattr(self.dataset, category + '_data')
        set_active = getattr(data, 'set_active_%s' % data_type)
        if len(name) == 0:
            # If the value is empty then we deactivate that attribute.
            set_active(None)
        else:
            set_active(name)
            aa = self._assign_attribute
            aa.assign(name, data_type.upper(), category.upper() + '_DATA')
            aa.update()

    def _find_array(self, name, category='point'):
        """Return information on which kind of attribute contains the
        specified named array in a particular category.

        Raises KeyError if the array is not found.
        """
        for attr_type in ('scalars', 'vectors', 'tensors'):
            if name in getattr(self, '%s_%s' % (category, attr_type)):
                return attr_type
        raise KeyError('No %s array named %s available in dataset'
                        %(category, name))
Exemple #10
0
class FreeStressPullout(RF):
    '''Pullout of fiber from a matrix; stress criterion for debonding, free fiber end

    NOTE(review): this class references names that are not defined in it:
    ``p`` (in ``_get_tau``), ``beta`` and ``rf`` (pull-out stage),
    ``get_clamp`` (in ``get_u``), ``P_a_residuum`` (in ``value_finite``) and
    ``tau_fr`` (in ``traits_view``).  Unless ``RF`` provides them, the code
    paths using them raise AttributeError -- confirm the intended
    definitions before relying on those paths.
    '''

    # implements( IRF )

    title = Str('one sided pull-out - short fiber with trilinear bond law')

    E_f = Float(200e+3,
                auto_set=False,
                enter_set=True,
                desc='filament stiffness [N/mm2]',
                distr=['uniform', 'norm'])

    d = Float(0.3,
              auto_set=False,
              enter_set=True,
              desc='filament diameter [mm]',
              distr=['uniform', 'norm'])

    z = Float(0.0,
              auto_set=False,
              enter_set=True,
              desc='fiber centroid distance from crack [mm]',
              distr=['uniform'])

    L_f = Float(17.0,
                auto_set=False,
                enter_set=True,
                desc='fiber length [mm]',
                distr=['uniform', 'norm'])

    k = Float(1.76,
              auto_set=False,
              enter_set=True,
              desc='bond shear stiffness [N/mm]',
              distr=['uniform', 'norm'])

    qf = Float(3.76,
               auto_set=False,
               enter_set=True,
               desc='bond shear stress [N/mm2]',
               distr=['uniform', 'norm'])

    qy = Float(19.76,
               auto_set=False,
               enter_set=True,
               desc='debbonding stress [N/mm2]',
               distr=['uniform', 'norm'])

    fu = Float(500,
               auto_set=False,
               enter_set=True,
               desc='fiber breaking stress [N/mm2]',
               distr=['uniform', 'norm'])

    l = Float(0.001,
              auto_set=False,
              enter_set=True,
              desc='free length',
              distr=['uniform', 'norm'])

    f = Float(0.03,
              auto_set=False,
              enter_set=True,
              desc='snubbing coefficient',
              distr=['uniform', 'norm'])

    phi = Float(0.0,
                auto_set=False,
                enter_set=True,
                desc='inclination angle',
                distr=['cos'])

    accuracy = Int(50, auto_set=False, enter_set=True)

    include_fu = Bool(False)

    u = Float(ctrl_range=(0, 0.016, 20), auto_set=False, enter_set=True)

    x_label = Str('displacement [mm]', enter_set=True, auto_set=False)
    y_label = Str('force [N]', enter_set=True, auto_set=False)

    tau = Property(Float, depends_on='qf', label='tau')

    def _get_tau(self):
        # fixme: 'p' is not a trait of this class (presumably the fiber
        # perimeter -- confirm); this raises AttributeError unless a base
        # class provides it.
        return self.qf / (self.p)

    # Fiber breaking force fu * Af, scaled by cos(phi).  'Af' is derived
    # from 'd', hence the dependency on 'd'.
    Pu = Property(Float, depends_on='fu, d, phi', label='Pu')

    def _get_Pu(self):
        return self.fu * self.Af * cos(self.phi)

    # sqrt(k / (E_f * Af)), the parameter used in the tanh terms below.
    w = Property(Float, depends_on='k, E_f, d', label='w')

    def _get_w(self):
        # The stiffness trait is 'E_f' ('Ef' is not defined on this class).
        return sqrt(self.k / self.E_f / self.Af)

    # Cross-sectional area of the fiber.
    Af = Property(Float, depends_on='d')

    def _get_Af(self):
        return pi * self.d**2 / 4.

    def get_P(self, a, qf, qy, L):
        ''' Returns the force for debonded length(s) a and embedded length L,
        using the given qf and qy (not the traits of the same name). '''
        return qf * a + qy / self.w * tanh(self.w * (L - a))

    def get_u(self, P, a):
        ''' takes a- and P-array and returns u-array '''
        Ef = self.E_f
        A = self.Af
        w = self.w

        # fixme: 'get_clamp' is not defined in this class.
        u = (P - self.qf * a) / Ef / A / w / self.get_clamp(a) + \
            (P - .5 * self.qf * a) / Ef / A * a + P * self.l / A / Ef
        return u

    def u_L0_residuum(self, L0, qf, L, Ef, A, l, qy=None):
        ''' Residuum in L0 used with brentq: its root is the embedded length
        of a purely frictional pull-out whose displacement equals the maximum
        displacement of the debonding stage.

        qy defaults to the 'qy' trait when not given.
        '''
        if qy is None:
            qy = self.qy

        a = linspace(0, L - L / 1e10, self.accuracy)
        P_deb = self.get_P(a, qf, qy, L)
        u_deb = self.get_u(P_deb, a)
        u_max = u_deb[argmax(u_deb)]

        # fixme: 'beta' and 'rf' are not traits of this class.
        P = qf * L0 * (1 + self.beta * (L - L0) / (2 * self.rf))
        delta_u = P * L0 / (2. * Ef * A)
        delta_free_l = (l + L - L0) * P / (Ef * A)
        delta_l = L - L0
        u = delta_u + delta_free_l + delta_l
        return u_max - u

    def continuous_function(self, u, E_f, L_f, d, qy, qf, k, z, phi, f):
        ''' Returns (u_deb, u_pull, P_deb, P_pull) for a fiber with infinite
        strength: the debonding branch with the snap back cut off, followed
        by the pull-out branch. '''

        L = L_f
        Ef = E_f
        A = self.Af

        a = linspace(0, L - L / 1e10, self.accuracy)
        # P-u diagram including snap back
        P_deb_full = self.get_P(a, qf, qy, L)
        u_deb_full = self.get_u(P_deb_full, a)
        idxmax = argmax(u_deb_full)
        # P-u diagram with the snap back cut off
        u_deb = u_deb_full[0:idxmax + 1]
        P_deb = P_deb_full[0:idxmax + 1]

        # pull-out stage
        # L0 is the embedded length of a pure frictional pull-out that
        # corresponds to the displacement at the end of the debonding stage.
        # The residuum takes extra arguments, which brentq forwards via 'args'.
        L0 = brentq(self.u_L0_residuum, 1e-12, 2 * L,
                    args=(qf, L, Ef, A, self.l, qy))
        if round(L, 7) >= round(L0, 7) >= 0:
            lp = linspace(L0, 0, 100)
            # fixme: 'beta' and 'rf' are not traits of this class.
            P_pull = qf * lp * (1 + self.beta * (L - lp) / (2 * self.rf))
            # displacement corresponding to the actual embedded length
            delta_u = P_pull * lp / (2. * Ef * A)
            # displacement corresponding to the actual free length
            delta_free_l = (self.l + L - lp) * P_pull / (Ef * A)
            # displacement corresponding to the free length increment
            delta_l = L - lp
            u_pull = delta_u + delta_free_l + delta_l
            return u_deb, u_pull, P_deb, P_pull

        # L0 is not in the interval [0, L]: the load drops to zero.
        u_pull = u_deb[-1]
        P_pull = 0
        return u_deb, u_pull, P_deb, P_pull

    def value_finite(self, u, E_f, L_f, d, qy, qf, k, z, phi, f):
        ''' returns the final x and y arrays for finite embedded length '''

        Pu = self.Pu
        A = self.Af
        w = self.w
        Ef = E_f

        u_deb, u_pull, P_deb, P_pull = self.continuous_function(
            u, E_f, L_f, d, qy, qf, k, z, phi, f)

        # Scale the forces by the snubbing factor exp(f * phi).
        snubbing = e**(self.f * self.phi)
        P_deb = P_deb * snubbing
        P_pull = P_pull * snubbing

        if all(hstack((P_deb, P_pull)) < Pu):
            # the pull-out force stays below the breaking force
            xdata = hstack((0, u_deb, u_pull))
            ydata = hstack((0, P_deb, P_pull))
        elif all(P_deb < Pu):
            # breaking force reached in the pull-out stage: truncate at the
            # first point that reaches it ('>=' so that P_pull == Pu is found,
            # matching the '< Pu' test above)
            idx_break = where(P_pull >= Pu)[0][0]
            xdata = hstack((0, u_deb, u_pull[:idx_break + 1]))
            ydata = hstack((0, P_deb, P_pull[:idx_break + 1]))
        elif P_deb[1] < Pu:
            # breaking force reached after the debonding has started
            # fixme: 'P_a_residuum' is not defined in this class.
            a_lim = brentq(self.P_a_residuum, 1e-12, 1e3)
            a = linspace(0, a_lim, 50)
            P_deb = self.get_P(a, qf, qy, L_f)
            u_deb = self.get_u(P_deb, a)
            xdata = hstack((0, u_deb, u_deb[-1]))
            ydata = hstack((0, P_deb, 0))
        else:
            # breaking force reached before the debonding has started
            u_max = Pu / (tanh(w * L_f) * Ef * A *
                          w) + Pu * self.l / A / Ef
            xdata = array([0, u_max, u_max])
            ydata = array([0, Pu, 0])
        return xdata, ydata

    def __call__(self, u, E_f, L_f, d, qy, qf, k, z, phi, f):
        ''' With include_fu set, returns the (xdata, ydata) arrays from
        value_finite; otherwise returns the force interpolated at
        displacement(s) u. '''
        if self.include_fu:
            return self.value_finite(u, E_f, L_f, d, qy, qf, k, z, phi, f)

        u_deb, u_pull, P_deb, P_pull = self.continuous_function(
            u, E_f, L_f, d, qy, qf, k, z, phi, f)
        snubbing = e**(self.f * self.phi)
        xdata = hstack((0, u_deb, u_pull))
        ydata = hstack((0, P_deb * snubbing, P_pull * snubbing))
        interp_func = interp1d(xdata, ydata)
        return interp_func(u)

    # NOTE(review): 'tau_fr' is not a trait of this class; the view will fail
    # on that item unless a base class defines it.
    traits_view = View(Item('E_f'),
                       Item('d'),
                       Item('f'),
                       Item('phi'),
                       Item('z'),
                       Item('tau_fr'),
                       resizable=True,
                       scrollable=True,
                       height=0.8,
                       width=0.8,
                       buttons=[OKButton, CancelButton])
Exemple #11
0
class ActionManager(HasTraits):
    """ Abstract base class for all action managers.

    An action manager contains a list of groups, with each group containing a
    list of items.

    There are currently three concrete sub-classes:-

    1) 'MenuBarManager'
    2) 'MenuManager'
    3) 'ToolBarManager'

    """

    #### 'ActionManager' interface ############################################

    # The Id of the default group.
    DEFAULT_GROUP = Constant('additions')

    # The action controller (if any) used to control how actions are performed.
    controller = Instance(ActionController)

    # Is the action manager enabled?
    enabled = Bool(True)

    # All of the contribution groups in the manager.
    groups = Property(List(Group))

    # The manager's unique identifier (if it has one).
    id = Str

    # Is the action manager visible?
    visible = Bool(True)

    #### Events ####

    # fixme: We probably need more granular events than this!
    changed = Event

    #### Private interface ####################################################

    # All of the contribution groups in the manager.
    _groups = List(Group)

    ###########################################################################
    # 'object' interface.
    ###########################################################################

    def __init__(self, *args, **traits):
        """ Creates a new menu manager. """

        # Base class constructor.
        super(ActionManager, self).__init__(**traits)

        # The last group in every manager is the group with Id 'additions'.
        #
        # fixme: The side-effect of this is to ensure that the 'additions'
        # group has been created.  Is the 'additions' group even a good idea?
        group = self._get_default_group()

        # Add all items to the manager.
        for arg in args:
            # We allow a group to be defined by simply specifying a string (its
            # Id).
            if isinstance(arg, basestring):
                # Create a group with the specified Id.
                arg = Group(id=arg)

            # If the item is a group then add it just before the default group
            # (ie. we always keep the default group as the last group in the
            # manager).
            if isinstance(arg, Group):
                self.insert(-1, arg)
                group = arg

            # Otherwise, the item is an action manager item so add it to the
            # current group.
            else:
                ##                 # If no group has been created then add one.  This is only
                ##                 # relevant when using the 'shorthand' way to define menus.
                ##                 if group is None:
                ##                     group = Group(id='__first__')
                ##                     self.insert(-1, group)

                group.append(arg)

        return

    ###########################################################################
    # 'ActionManager' interface.
    ###########################################################################

    #### Trait properties #####################################################

    def _get_groups(self):
        """ Returns the groups in the manager. """

        return self._groups[:]

    #### Trait change handlers ################################################

    def _enabled_changed(self, trait_name, old, new):
        """ Static trait change handler. """

        for group in self._groups:
            group.enabled = new

        return

    def _visible_changed(self, trait_name, old, new):
        """ Static trait change handler. """

        for group in self._groups:
            group.visible = new

        return

    #### Methods ##############################################################

    def append(self, item):
        """ Append an item to the manager.

        See the documentation for 'insert'.

        """

        return self.insert(len(self._groups), item)

    def destroy(self):
        """ Called when the manager is no longer required.

        By default this method simply calls 'destroy' on all of the manager's
        groups.

        """

        for group in self.groups:
            group.destroy()

        return

    def insert(self, index, item):
        """ Insert an item into the manager at the specified index.

        The item can be:-

        1) A 'Group' instance.

            In which case the group is inserted into the manager's list of
            groups.

        2) A string.

            In which case a 'Group' instance is created with that Id, and then
            inserted into the manager's list of groups.

        3) An 'ActionManagerItem' instance.

            In which case the item is inserted into the manager's default
            group.

        """

        # 1) The item is a 'Group' instance.
        if isinstance(item, Group):
            group = item

            # Insert the group into the manager.
            group.parent = self
            self._groups.insert(index, item)

        # 2) The item is a string.
        elif isinstance(item, basestring):
            # Create a group with that Id.
            group = Group(id=item)

            # Insert the group into the manager.
            group.parent = self
            self._groups.insert(index, group)

        # 3) The item is an 'ActionManagerItem' instance.
        else:
            # Find the default group.
            group = self._get_default_group()

            # Insert the item into the default group.
            group.insert(index, item)

        return group

    def find_group(self, id):
        """ Return the group with the specified Id.

        Return None if no such group exists.

        """

        for group in self._groups:
            if group.id == id:
                break

        else:
            group = None

        return group

    def find_item(self, path):
        """ Return the item found at the specified path.

        'path' is a '/' separated list of contribution Ids.

        Returns None if any component of the path is not found.

        """

        components = path.split('/')

        # If there is only one component, then the path is just an Id so look
        # it up in this manager.
        if len(components) > 0:
            item = self._find_item(components[0])

            if len(components) > 1 and item is not None:
                item = item.find_item('/'.join(components[1:]))

        else:
            item = None

        return item

    def walk(self, fn):
        """ Walk the manager applying a function at every item. """

        fn(self)

        for group in self._groups:
            self.walk_group(group, fn)

        return

    def walk_group(self, group, fn):
        """ Walk a group applying a function at every item. """

        fn(group)

        for item in group.items:
            if isinstance(item, Group):
                self.walk_group(item, fn)

            else:
                self.walk_item(item, fn)

        return

    def walk_item(self, item, fn):
        """ Walk an item (may be a sub-menu manager remember!). """

        if hasattr(item, 'groups'):
            item.walk(fn)

        else:
            fn(item)

        return

    ###########################################################################
    # Private interface.
    ###########################################################################

    def _get_default_group(self):
        """ Returns the manager's default group. """

        group = self.find_group(self.DEFAULT_GROUP)
        if group is None:
            group = Group(id=self.DEFAULT_GROUP)
            self.append(group)

        return group

    def _find_item(self, id):
        """ Returns the item with the specified Id.

        Returns None if no such item exists.

        """

        for group in self.groups:
            item = group.find(id)
            if item is not None:
                break

        else:
            item = None

        return item

    ###########################################################################
    # Debugging interface.
    ###########################################################################

    def dump(self, indent=''):
        """ Render a manager! """

        print indent, 'Manager', self.id
        indent += '  '

        for group in self._groups:
            self.render_group(group, indent)

        return

    def render_group(self, group, indent=''):
        """ Render a group! """

        print indent, 'Group', group.id
        indent += '    '

        for item in group.items:
            if isinstance(item, Group):
                print 'Surely, a group cannot contain another group!!!!'
                self.render_group(item, indent)

            else:
                self.render_item(item, indent)

        return

    def render_item(self, item, indent=''):
        """ Render an item! """

        if hasattr(item, 'groups'):
            item.dump(indent)

        else:
            print indent, 'Item', item.id

        return
Exemple #12
0
class ToolkitEditorFactory(EditorFactory):
    """ Editor factory for range values.

    Produces a slider, spin box, enum-style or text editor depending on the
    range bounds and on whether the range is a float range.
    """

    #---------------------------------------------------------------------------
    #  Trait definitions:
    #---------------------------------------------------------------------------

    cols = Range(1, 20)  # Number of columns when displayed as an enum
    auto_set = true  # Is user input set on every keystroke?
    enter_set = false  # Is user input set on enter key?
    low_label = Str  # Label for low end of range
    high_label = Str  # Label for high end of range
    is_float = true  # Is the range float (or int)?

    #---------------------------------------------------------------------------
    #  Performs any initialization needed after all constructor traits have
    #  been set:
    #---------------------------------------------------------------------------

    def init(self, handler=None):
        """ Performs any initialization needed after all constructor traits 
            have been set.

            If a trait (or trait handler) is given, its 'low' and 'high'
            bounds are copied onto this factory.
        """
        if handler is not None:
            if isinstance(handler, CTrait):
                handler = handler.handler
            self.low = handler.low
            self.high = handler.high

    #---------------------------------------------------------------------------
    #  Define the 'low' and 'high' traits:
    #---------------------------------------------------------------------------

    def _get_low(self):
        return self._low

    def _set_low(self, low):
        self._low = low
        self.is_float = self._bounds_are_float()
        if self.low_label == '':
            self.low_label = str(low)

    def _get_high(self):
        return self._high

    def _set_high(self, high):
        self._high = high
        self.is_float = self._bounds_are_float()
        if self.high_label == '':
            self.high_label = str(high)

    def _bounds_are_float(self):
        """ Returns True if either bound assigned so far is a float.

            Both bounds are considered so that the result does not depend on
            which bound was assigned last (e.g. low=0.5, high=10 is a float
            range, so 'custom_editor' must not build an integer enum from it).
        """
        bounds = (getattr(self, '_low', None), getattr(self, '_high', None))
        return any(isinstance(bound, float) for bound in bounds)

    low = Property(_get_low, _set_low)
    high = Property(_get_high, _set_high)

    #---------------------------------------------------------------------------
    #  'Editor' factory methods:
    #---------------------------------------------------------------------------

    def simple_editor(self, ui, object, name, description, parent):
        # Slider for float ranges and small int ranges; spin box otherwise.
        if self.is_float or (abs(self.high - self.low) <= 100):
            return SimpleSliderEditor(parent,
                                      factory=self,
                                      ui=ui,
                                      object=object,
                                      name=name,
                                      description=description)
        return SimpleSpinEditor(parent,
                                factory=self,
                                ui=ui,
                                object=object,
                                name=name,
                                description=description)

    def custom_editor(self, ui, object, name, description, parent):
        # Float ranges and int ranges spanning more than 15 fall back to the
        # simple editor; small int ranges are shown as an enum.
        if self.is_float or (abs(self.high - self.low) > 15):
            return self.simple_editor(ui, object, name, description, parent)

        # '_enum' is created lazily and is not a declared trait, so read it
        # with a default in case it has never been assigned.
        if getattr(self, '_enum', None) is None:
            import enum_editor
            self._enum = enum_editor.ToolkitEditorFactory(
                values=range(self.low, self.high + 1), cols=self.cols)
        return self._enum.custom_editor(ui, object, name, description, parent)

    def text_editor(self, ui, object, name, description, parent):
        return RangeTextEditor(parent,
                               factory=self,
                               ui=ui,
                               object=object,
                               name=name,
                               description=description)
Exemple #13
0
class Edge(HasTraits):
    """ Defines a graph edge. """

    #--------------------------------------------------------------------------
    #  Trait definitions:
    #--------------------------------------------------------------------------

    # Tail/from/source/start node.
    tail_node = Instance(Node, allow_none=False)

    # Head/to/target/end node.
    head_node = Instance(Node, allow_none=False)

    # String identifier (TreeNode label).
    name = Property(
        Str,
        depends_on=["tail_node", "tail_node.ID", "head_node", "head_node.ID"])

    # Connection string used in string output.
    conn = Enum("->", "--")

    # Nodes from which the tail and head nodes may be selected.
    _nodes = List(Instance(Node))  # GUI specific.

    #--------------------------------------------------------------------------
    #  Xdot trait definitions:
    #--------------------------------------------------------------------------

    # For a given graph object, one will typically a draw directive before the
    # label directive. For example, for a node, one would first use the
    # commands in _draw_ followed by the commands in _ldraw_.
    _draw_ = Str(desc="xdot drawing directive", label="draw")
    _ldraw_ = Str(desc="xdot label drawing directive", label="ldraw")

    _hdraw_ = Str(desc="edge head arrowhead drawing directive.", label="hdraw")
    _tdraw_ = Str(desc="edge tail arrowhead drawing directive.", label="tdraw")
    _hldraw_ = Str(desc="edge head label drawing directive.", label="hldraw")
    _tldraw_ = Str(desc="edge tail label drawing directive.", label="tldraw")

    #--------------------------------------------------------------------------
    #  Enable trait definitions:
    #--------------------------------------------------------------------------

    # Container of drawing components, typically the edge spline.
    drawing = Instance(Container)

    # Container of label components.
    label_drawing = Instance(Container)

    # Container of head arrow components.
    arrowhead_drawing = Instance(Container)

    # Container of tail arrow components.
    arrowtail_drawing = Instance(Container)

    # Container of head arrow label components.
    arrowhead_label_drawing = Instance(Container)

    # Container of tail arrow label components.
    arrowtail_label_drawing = Instance(Container)

    # Container for the drawing, label, arrow and arrow label components.
    component = Instance(Container, desc="container of graph components.")

    # A view into a sub-region of the canvas.
    vp = Instance(Viewport, desc="a view of a sub-region of the canvas")

    # Use Graphviz to arrange all graph components.
    arrange = Button("Arrange All")

    #--------------------------------------------------------------------------
    #  Dot trait definitions:
    #--------------------------------------------------------------------------

    # Style of arrowhead on the head node of an edge.
    # See also the <html:a rel="attr">dir</html:a> attribute,
    # and the <html:a rel="note">undirected</html:a> note.
    arrowhead = arrow_trait

    # Multiplicative scale factor for arrowheads.
    arrowsize = Float(1.0,
                      desc="multiplicative scale factor for arrowheads",
                      label="Arrow size",
                      graphviz=True)

    # Style of arrowhead on the tail node of an edge.
    # See also the <html:a rel="attr">dir</html:a> attribute,
    # and the <html:a rel="note">undirected</html:a> note.
    arrowtail = arrow_trait

    # Basic drawing color for graphics, not text. For the latter, use the
    # <html:a rel="attr">fontcolor</html:a> attribute.
    #
    # For edges, the value
    # can either be a single <html:a rel="type">color</html:a> or a <html:a rel="type">colorList</html:a>.
    # In the latter case, the edge is drawn using parallel splines or lines,
    # one for each color in the list, in the order given.
    # The head arrow, if any, is drawn using the first color in the list,
    # and the tail arrow, if any, the second color. This supports the common
    # case of drawing opposing edges, but using parallel splines instead of
    # separately routed multiedges.
    color = color_trait

    # This attribute specifies a color scheme namespace. If defined, it specifies
    # the context for interpreting color names. In particular, if a
    # <html:a rel="type">color</html:a> value has form <html:code>xxx</html:code> or <html:code>//xxx</html:code>,
    # then the color <html:code>xxx</html:code> will be evaluated according to the current color scheme.
    # If no color scheme is set, the standard X11 naming is used.
    # For example, if <html:code>colorscheme=bugn9</html:code>, then <html:code>color=7</html:code>
    # is interpreted as <html:code>/bugn9/7</html:code>.
    colorscheme = color_scheme_trait

    # Comments are inserted into output. Device-dependent.
    comment = comment_trait

    # If <html:span class="val">false</html:span>, the edge is not used in
    # ranking the nodes.
    constraint = Bool(True,
                      desc="if edge is used in ranking the nodes",
                      graphviz=True)

    # If <html:span class="val">true</html:span>, attach edge label to edge by a 2-segment
    # polyline, underlining the label, then going to the closest point of spline.
    decorate = Bool(
        False,
        desc="to attach edge label to edge by a 2-segment "
        "polyline, underlining the label, then going to the closest point of "
        "spline",
        graphviz=True)

    # Set edge type for drawing arrowheads. This indicates which ends of the
    # edge should be decorated with an arrowhead. The actual style of the
    # arrowhead can be specified using the <html:a rel="attr">arrowhead</html:a>
    # and <html:a rel="attr">arrowtail</html:a> attributes.
    # See <html:a rel="note">undirected</html:a>.
    dir = Enum("forward",
               "back",
               "both",
               "none",
               label="Direction",
               desc="edge type for drawing arrowheads",
               graphviz=True)

    # Synonym for <html:a rel="attr">edgeURL</html:a>.
    #    edgehref = Alias("edgeURL", desc="synonym for edgeURL")
    edgehref = Synced(sync_to="edgeURL", graphviz=True)

    # If the edge has a URL or edgeURL  attribute, this attribute determines
    # which window of the browser is used for the URL attached to the non-label
    # part of the edge. Setting it to "_graphviz" will open a new window if it
    # doesn't already exist, or reuse it if it does. If undefined, the value of
    # the target is used.
    edgetarget = Str("",
                     desc="which window of the browser is used for the "
                     "URL attached to the non-label part of the edge",
                     label="Edge target",
                     graphviz=True)

    # Tooltip annotation attached to the non-label part of an edge.
    # This is used only if the edge has a <html:a rel="attr">URL</html:a>
    # or <html:a rel="attr">edgeURL</html:a> attribute.
    edgetooltip = Str("",
                      desc="annotation attached to the non-label part of "
                      "an edge",
                      label="Edge tooltip",
                      graphviz=True)
    #    edgetooltip = EscString

    # If <html:a rel="attr">edgeURL</html:a> is defined, this is the link used for the non-label
    # parts of an edge. This value overrides any <html:a rel="attr">URL</html:a>
    # defined for the edge.
    # Also, this value is used near the head or tail node unless overridden
    # by a <html:a rel="attr">headURL</html:a> or <html:a rel="attr">tailURL</html:a> value,
    # respectively.
    # See <html:a rel="note">undirected</html:a>.
    edgeURL = Str("",
                  desc="link used for the non-label parts of an edge",
                  label="Edge URL",
                  graphviz=True)  #LabelStr

    # Color used for text.
    fontcolor = fontcolor_trait

    # Font used for text. This very much depends on the output format and, for
    # non-bitmap output such as PostScript or SVG, the availability of the font
    # when the graph is displayed or printed. As such, it is best to rely on
    # font faces that are generally available, such as Times-Roman, Helvetica or
    # Courier.
    #
    # If Graphviz was built using the
    # <html:a href="http://pdx.freedesktop.org/~fontconfig/fontconfig-user.html">fontconfig library</html:a>, the latter library
    # will be used to search for the font. However, if the <html:a rel="attr">fontname</html:a> string
    # contains a slash character "/", it is treated as a pathname for the font
    # file, though font lookup will append the usual font suffixes.
    #
    # If Graphviz does not use fontconfig, <html:a rel="attr">fontname</html:a> will be
    # considered the name of a Type 1 or True Type font file.
    # If you specify <html:code>fontname=schlbk</html:code>, the tool will look for a
    # file named  <html:code>schlbk.ttf</html:code> or <html:code>schlbk.pfa</html:code> or <html:code>schlbk.pfb</html:code>
    # in one of the directories specified by
    # the <html:a rel="attr">fontpath</html:a> attribute.
    # The lookup does support various aliases for the common fonts.
    fontname = fontname_trait

    # Font size, in <html:a rel="note">points</html:a>, used for text.
    fontsize = fontsize_trait

    # If <html:span class="val">true</html:span>, the head of an edge is clipped to the boundary of the head node;
    # otherwise, the end of the edge goes to the center of the node, or the
    # center of a port, if applicable.
    headclip = Bool(True,
                    desc="head of an edge to be clipped to the boundary "
                    "of the head node",
                    label="Head clip",
                    graphviz=True)

    # Synonym for <html:a rel="attr">headURL</html:a>.
    headhref = Alias("headURL", desc="synonym for headURL", graphviz=True)

    # Text label to be placed near head of edge.
    # See <html:a rel="note">undirected</html:a>.
    headlabel = Str("",
                    desc="text label to be placed near head of edge",
                    label="Head label",
                    graphviz=True)

    # Indicates where on the head node to attach the head of the edge.
    headport = port_pos_trait

    # If the edge has a headURL, this attribute determines which window of the
    # browser is used for the URL. Setting it to "_graphviz" will open a new
    # window if it doesn't already exist, or reuse it if it does. If undefined,
    # the value of the target is used.
    headtarget = Str(desc="which window of the browser is used for the URL",
                     label="Head target",
                     graphviz=True)

    # Tooltip annotation attached to the head of an edge. This is used only
    # if the edge has a <html:a rel="attr">headURL</html:a> attribute.
    headtooltip = Str("",
                      desc="tooltip annotation attached to the head of an "
                      "edge",
                      label="Head tooltip",
                      graphviz=True)

    # If <html:a rel="attr">headURL</html:a> is defined, it is
    # output as part of the head label of the edge.
    # Also, this value is used near the head node, overriding any
    # <html:a rel="attr">URL</html:a> value.
    # See <html:a rel="note">undirected</html:a>.
    headURL = Str("",
                  desc="output as part of the head label of the edge",
                  label="Head URL",
                  graphviz=True)

    # Synonym for <html:a rel="attr">URL</html:a>.
    href = Alias("URL", desc="synonym for URL", graphviz=True)

    # Text label attached to objects.
    # If a node's <html:a rel="attr">shape</html:a> is record, then the label can
    # have a <html:a href="http://www.graphviz.org/doc/info/shapes.html#record">special format</html:a>
    # which describes the record layout.
    label = label_trait

    # This, along with <html:a rel="attr">labeldistance</html:a>, determines
    # where the headlabel (taillabel) is placed with respect to the head
    # (tail) in polar coordinates. The origin in the coordinate system is
    # the point where the edge touches the node. The ray of 0 degrees
    # goes from the origin back along the edge, parallel to the edge
    # at the origin.
    #
    # The angle, in degrees, specifies the rotation from the 0 degree ray,
    # with positive angles moving counterclockwise and negative angles
    # moving clockwise.
    labelangle = Float(
        -25.0,
        desc="angle, along with labeldistance, determining where the "
        "headlabel (taillabel) is placed with respect to the head (tail)",
        label="Label angle",
        graphviz=True)

    # Multiplicative scaling factor adjusting the distance that
    # the headlabel (taillabel) is from the head (tail) node.
    # The default distance is 10 points. See <html:a rel="attr">labelangle</html:a>
    # for more details.
    labeldistance = Float(
        1.0,
        desc="multiplicative scaling factor adjusting "
        "the distance that the headlabel (taillabel) is from the head (tail) "
        "node",
        label="Label distance",
        graphviz=True)

    # If true, allows edge labels to be less constrained in position. In
    # particular, it may appear on top of other edges.
    labelfloat = Bool(False,
                      desc="edge labels to be less constrained in "
                      "position",
                      label="Label float",
                      graphviz=True)

    # Color used for headlabel and taillabel.
    # If not set, defaults to edge's fontcolor.
    labelfontcolor = Color("black",
                           desc="color used for headlabel and "
                           "taillabel",
                           label="Label font color",
                           graphviz=True)

    # Font used for headlabel and taillabel.
    # If not set, defaults to edge's fontname.
    labelfontname = Font("Times-Roman",
                         desc="font used for headlabel and "
                         "taillabel",
                         label="Label font name",
                         graphviz=True)

    # Font size, in <html:a rel="note">points</html:a>, used for headlabel and taillabel.
    # If not set, defaults to edge's fontsize.
    labelfontsize = Float(14.0,
                          desc="font size, in points, used for "
                          "headlabel and taillabel",
                          label="Label font size",
                          graphviz=True)

    # Synonym for <html:a rel="attr">labelURL</html:a>.
    labelhref = Alias("labelURL", desc="synonym for labelURL", graphviz=True)

    # If the edge has a URL or labelURL attribute, this attribute determines
    # which window of the browser is used for the URL attached to the label.
    # Setting it to "_graphviz" will open a new window if it doesn't already
    # exist, or reuse it if it does. If undefined, the value of the target is
    # used.
    labeltarget = Str("",
                      desc="which window of the browser is used for the "
                      "URL attached to the label",
                      label="Label target",
                      graphviz=True)

    # Tooltip annotation attached to label of an edge.
    # This is used only if the edge has a <html:a rel="attr">URL</html:a>
    # or <html:a rel="attr">labelURL</html:a> attribute.
    labeltooltip = Str("",
                       desc="tooltip annotation attached to label of an "
                       "edge",
                       label="Label tooltip",
                       graphviz=True)

    # If <html:a rel="attr">labelURL</html:a> is defined, this is the link used for the label
    # of an edge. This value overrides any <html:a rel="attr">URL</html:a>
    # defined for the edge.
    labelURL = Str(desc="link used for the label of an edge",
                   label="Label URL",
                   graphviz=True)

    # Specifies layers in which the node or edge is present.
    layer = layer_trait

    # Preferred edge length, in inches.
    len = Float(1.0, desc="preferred edge length, in inches",
                graphviz=True)  # 0.3 under fdp

    # Logical head of an edge. When compound is true, if lhead is defined and
    # is the name of a cluster containing the real head, the edge is clipped to
    # the boundary of the cluster.
    lhead = Str(desc="logical head of an edge", graphviz=True)

    # Label position, in points. The position indicates the center of the label.
    lp = point_trait

    # Logical tail of an edge. When compound is true, if ltail is defined and
    # is the name of a cluster containing the real tail, the edge is clipped to
    # the boundary of the cluster.
    ltail = Str(desc="logical tail of an edge", graphviz=True)

    # Minimum edge length (rank difference between head and tail).
    minlen = Int(1, desc="minimum edge length", graphviz=True)

    # By default, the justification of multi-line labels is done within the
    # largest context that makes sense. Thus, in the label of a polygonal node,
    # a left-justified line will align with the left side of the node (shifted
    # by the prescribed margin). In record nodes, left-justified line will line
    # up with the left side of the enclosing column of fields. If nojustify is
    # "true", multi-line labels will be justified in the context of itself. For
    # example, if the attribute is set, the first label line is long, and the
    # second is shorter and left-justified, the second will align with the
    # left-most character in the first line, regardless of how large the node
    # might be.
    nojustify = nojustify_trait

    # Position of node, or spline control points.
    # For nodes, the position indicates the center of the node.
    # On output, the coordinates are in <html:a href="#points">points</html:a>.
    #
    # In neato and fdp, pos can be used to set the initial position of a node.
    # By default, the coordinates are assumed to be in inches. However, the
    # <html:a href="http://www.graphviz.org/doc/info/command.html#d:s">-s</html:a> command line flag can be used to specify
    # different units.
    #
    # When the <html:a href="http://www.graphviz.org/doc/info/command.html#d:n">-n</html:a> command line flag is used with
    # neato, it is assumed the positions have been set by one of the layout
    # programs, and are therefore in points. Thus, <html:code>neato -n</html:code> can accept
    # input correctly without requiring a <html:code>-s</html:code> flag and, in fact,
    # ignores any such flag.
    #
    # NOTE(review): unlike most traits here, pos carries no 'graphviz'
    # metadata, so __str__ never emits it -- presumably intentional, since
    # str() of a list of tuples is not valid DOT; confirm.
    pos = List(Tuple(Float, Float), desc="spline control points")

    # Edges with the same head and the same <html:a rel="attr">samehead</html:a> value are aimed
    # at the same point on the head.
    # See <html:a rel="note">undirected</html:a>.
    samehead = Str("",
                   desc="edges with the same head and the same samehead "
                   "value are aimed at the same point on the head",
                   graphviz=True)

    # Edges with the same tail and the same <html:a rel="attr">sametail</html:a> value are aimed
    # at the same point on the tail.
    # See <html:a rel="note">undirected</html:a>.
    sametail = Str("",
                   desc="edges with the same tail and the same sametail "
                   "value are aimed at the same point on the tail",
                   graphviz=True)

    # Print guide boxes in PostScript at the beginning of
    # routesplines if 1, or at the end if 2. (Debugging)
    showboxes = showboxes_trait

    # Set style for node or edge. For cluster subgraph, if "filled", the
    # cluster box's background is filled.
    style = ListStr(desc="style for node or edge", graphviz=True)

    # If <html:span class="val">true</html:span>, the tail of an edge is clipped to the boundary of the tail node;
    # otherwise, the end of the edge goes to the center of the node, or the
    # center of a port, if applicable.
    tailclip = Bool(True,
                    desc="tail of an edge to be clipped to the boundary "
                    "of the tail node",
                    label="Tail clip",
                    graphviz=True)

    # Synonym for <html:a rel="attr">tailURL</html:a>.
    tailhref = Alias("tailURL", desc="synonym for tailURL", graphviz=True)

    # Text label to be placed near tail of edge.
    # See <html:a rel="note">undirected</html:a>.
    taillabel = Str(desc="text label to be placed near tail of edge",
                    label="Tail label",
                    graphviz=True)

    # Indicates where on the tail node to attach the tail of the edge.
    tailport = port_pos_trait

    # If the edge has a tailURL, this attribute determines which window of the
    # browser is used for the URL. Setting it to "_graphviz" will open a new
    # window if it doesn't already exist, or reuse it if it does. If undefined,
    # the value of the target is used.
    tailtarget = Str(desc="which window of the browser is used for the URL",
                     label="Tail target",
                     graphviz=True)

    # Tooltip annotation attached to the tail of an edge. This is used only
    # if the edge has a <html:a rel="attr">tailURL</html:a> attribute.
    tailtooltip = Str("",
                      desc="tooltip annotation attached to the tail of an "
                      "edge",
                      label="Tail tooltip",
                      graphviz=True)

    # If <html:a rel="attr">tailURL</html:a> is defined, it is
    # output as part of the tail label of the edge.
    # Also, this value is used near the tail node, overriding any
    # <html:a rel="attr">URL</html:a> value.
    # See <html:a rel="note">undirected</html:a>.
    tailURL = Str("",
                  desc="output as part of the tail label of the edge",
                  label="Tail URL",
                  graphviz=True)

    # If the object has a URL, this attribute determines which window
    # of the browser is used for the URL.
    # See <html:a href="http://www.w3.org/TR/html401/present/frames.html#adef-target">W3C documentation</html:a>.
    target = target_trait

    # Tooltip annotation attached to the node or edge. If unset, Graphviz
    # will use the object's <html:a rel="attr">label</html:a> if defined.
    # Note that if the label is a record specification or an HTML-like
    # label, the resulting tooltip may be unhelpful. In this case, if
    # tooltips will be generated, the user should set a <html:tt>tooltip</html:tt>
    # attribute explicitly.
    tooltip = tooltip_trait

    # Hyperlinks incorporated into device-dependent output.
    # At present, used in ps2, cmap, i*map and svg formats.
    # For all these formats, URLs can be attached to nodes, edges and
    # clusters. URL attributes can also be attached to the root graph in ps2,
    # cmap and i*map formats. This serves as the base URL for relative URLs in the
    # former, and as the default image map file in the latter.
    #
    # For svg, cmapx and imap output, the active area for a node is its
    # visible image.
    # For example, an unfilled node with no drawn boundary will only be active on its label.
    # For other output, the active area is its bounding box.
    # The active area for a cluster is its bounding box.
    # For edges, the active areas are small circles where the edge contacts its head
    # and tail nodes. In addition, for svg, cmapx and imap, the active area
    # includes a thin polygon approximating the edge. The circles may
    # overlap the related node, and the edge URL dominates.
    # If the edge has a label, this will also be active.
    # Finally, if the edge has a head or tail label, this will also be active.
    #
    # Note that, for edges, the attributes <html:a rel="attr">headURL</html:a>,
    # <html:a rel="attr">tailURL</html:a>, <html:a rel="attr">labelURL</html:a> and
    # <html:a rel="attr">edgeURL</html:a> allow control of various parts of an
    # edge. Also note that, if active areas of two edges overlap, it is unspecified
    # which area dominates.
    URL = url_trait

    # Weight of edge. In dot, the heavier the weight, the shorter, straighter
    # and more vertical the edge is.
    weight = Float(1.0, desc="weight of edge", graphviz=True)

    #--------------------------------------------------------------------------
    #  Views:
    #--------------------------------------------------------------------------

    # Default editing view: the rendered edge (with an "arrange" button) on
    # one side of a splitter and the Graphviz attributes, grouped into tabs,
    # on the other.  Kept as one nested expression on purpose: any View
    # element bound to its own class attribute would be picked up by Traits
    # as an extra named view element.
    traits_view = View(
        VGroup(
            Group(
                Item(name="vp",
                     editor=ComponentEditor(height=100),
                     show_label=False,
                     id=".component"),
                Item("arrange", show_label=False),
            ),
            Tabbed(
                Group(
                    Item(name="tail_node",
                         editor=InstanceEditor(name="_nodes",
                                               editable=False)),
                    Item(name="head_node",
                         editor=InstanceEditor(name="_nodes",
                                               editable=False)),
                    ["style", "layer", "color", "colorscheme", "dir",
                     "arrowsize", "constraint", "decorate", "showboxes",
                     "tooltip", "edgetooltip", "edgetarget", "target",
                     "comment"],
                    label="Edge",
                ),
                Group(["label", "fontname", "fontsize", "fontcolor",
                       "nojustify", "labeltarget", "labelfloat",
                       "labelfontsize", "labeltooltip", "labelangle", "lp",
                       "labelURL", "labelfontname", "labeldistance",
                       "labelfontcolor", "labelhref"],
                      label="Label"),
                Group(["minlen", "weight", "len", "pos"],
                      label="Dimension"),
                Group(["arrowhead", "samehead", "headURL", "headtooltip",
                       "headclip", "headport", "headlabel", "headtarget",
                       "lhead", "headhref"],
                      label="Head"),
                Group(["arrowtail", "tailtarget", "tailhref", "ltail",
                       "sametail", "tailport", "taillabel", "tailtooltip",
                       "tailURL", "tailclip"],
                      label="Tail"),
                Group(["URL", "href", "edgeURL", "edgehref"],
                      label="URL"),
                Group(["_draw_", "_ldraw_", "_hdraw_", "_tdraw_",
                       "_hldraw_", "_tldraw_"],
                      label="Xdot"),
                dock="tab",
            ),
            layout="split",
            id=".splitter",
        ),
        title="Edge",
        id="godot.edge",
        buttons=["OK", "Cancel", "Help"],
        resizable=True,
    )

    #--------------------------------------------------------------------------
    #  "object" interface:
    #--------------------------------------------------------------------------

    def __init__(self,
                 tailnode_or_ID,
                 headnode_or_ID,
                 directed=False,
                 **traits):
        """ Initialises a new Edge instance.

        'tailnode_or_ID' and 'headnode_or_ID' may each be an existing Node,
        or any other value whose str() is used as the ID of a new Node.
        'directed' selects the connector: "->" when true, "--" otherwise.
        The remaining keyword arguments are handed to the base initialiser
        as trait values.
        """
        def as_node(node_or_id):
            # Existing nodes are used as-is; anything else names a new node.
            if isinstance(node_or_id, Node):
                return node_or_id
            return Node(str(node_or_id))

        tail = as_node(tailnode_or_ID)
        head = as_node(headnode_or_ID)

        self.tail_node = tail
        self.head_node = head
        self.conn = "->" if directed else "--"

        super(Edge, self).__init__(**traits)

    def __str__(self):
        """ Returns a string representation of the edge in DOT syntax.

        The result has the form
        '<tail ID><tailport> <conn> <head ID><headport> [name=value, ...];'.
        Only traits carrying 'graphviz' metadata whose value differs from a
        non-None default are listed.  String values are double-quoted and any
        double quote inside them is escaped with a backslash, so the output
        stays valid DOT.
        """
        attrs = []
        # Traits to be included in string output have 'graphviz' metadata.
        for trait_name, trait in self.traits(graphviz=True).iteritems():
            # Get the value of the trait for comparison with the default.
            value = getattr(self, trait_name)

            # Only print attribute value pairs if not defaulted.
            # FIXME: Alias/Synced traits default to None.
            if (value != trait.default) and (trait.default is not None):
                if isinstance(value, basestring):
                    # In a DOT quoted string the only escape sequence is \"
                    # for a literal double quote; an unescaped '"' would end
                    # the string early and corrupt the attribute list.
                    valstr = '"%s"' % value.replace('"', '\\"')
                else:
                    valstr = str(value)

                attrs.append('%s=%s' % (trait_name, valstr))

        if attrs:
            attrstr = " [%s]" % ", ".join(attrs)
        else:
            attrstr = ""

        edge_str = "%s%s %s %s%s%s;" % (self.tail_node.ID, self.tailport,
                                        self.conn, self.head_node.ID,
                                        self.headport, attrstr)
        return edge_str

    #--------------------------------------------------------------------------
    #  Trait initialisers:
    #--------------------------------------------------------------------------

    def _component_default(self):
        """ Trait initialiser for 'component'.

        Returns an auto-sizing Container with a green background; the edge's
        drawing containers are added to it as they are set.
        """
        return Container(bgcolor="green", auto_size=True)

    def _vp_default(self):
        """ Trait initialiser for 'vp'.

        Builds a zoom-enabled Viewport onto 'component' with a pan tool.
        """
        viewport = Viewport(component=self.component)
        viewport.enable_zoom = True
        pan_tool = ViewportPanTool(viewport)
        viewport.tools.append(pan_tool)
        return viewport

    #--------------------------------------------------------------------------
    #  Property getters:
    #--------------------------------------------------------------------------

    def _get_name(self):
        """ Property getter for 'name'.

        Returns "<tail ID> <conn> <head ID>" once both end nodes are set,
        and the placeholder "Edge" while either of them is missing.
        """
        tail, head = self.tail_node, self.head_node
        if tail is None or head is None:
            return "Edge"
        return "%s %s %s" % (tail.ID, self.conn, head.ID)

    #--------------------------------------------------------------------------
    #  Event handlers:
    #--------------------------------------------------------------------------

    @on_trait_change("arrange")
    def arrange_all(self):
        """ Arrange the components of the edge using Graphviz.

        The edge is added to a temporary directed graph, which is rendered to
        xdot format.  The attributes reported for every "add_edge" element
        are then set on this edge.  The edge's connector is switched to "->"
        only while the temporary directed graph is rendered, and is restored
        afterwards, so arranging does not change how the edge is written out.
        """
        # FIXME: Circular reference avoidance.
        import godot.dot_data_parser
        import godot.graph

        graph = godot.graph.Graph(ID="g", directed=True)

        # A directed graph requires "->"; remember the caller's connector so
        # that it can be put back once the xdot output has been produced.
        original_conn = self.conn
        self.conn = "->"
        try:
            graph.edges.append(self)
            xdot_data = graph.create(format="xdot")
        finally:
            self.conn = original_conn

        parser = godot.dot_data_parser.GodotDataParser()
        # Join lines that Graphviz split with a trailing backslash.
        ndata = xdot_data.replace('\\\n', '')
        tokens = parser.dotparser.parseString(ndata)[0]

        # NOTE(review): index 3 is assumed to hold the graph's statement
        # list -- confirm against the GodotDataParser grammar.
        for element in tokens[3]:
            if element[0] == "add_edge":
                _cmd, _src, _dest, opts = element
                self.set(**opts)

#    @on_trait_change("_draw_,_hdraw_")

    def _parse_xdot_directive(self, name, new):
        """ Handles parsing Xdot drawing directives.
        """
        parser = XdotAttrParser()
        components = parser.parse_xdot_data(new)

        # The absolute coordinate of the drawing container wrt graph origin.
        x1 = min([c.x for c in components])
        y1 = min([c.y for c in components])

        print "X1/Y1:", name, x1, y1

        # Components are positioned relative to their container. This
        # function positions the bottom-left corner of the components at
        # their origin rather than relative to the graph.
        #        move_to_origin( components )

        for c in components:
            if isinstance(c, Ellipse):
                component.x_origin -= x1
                component.y_origin -= y1
#                c.position = [ c.x - x1, c.y - y1 ]

            elif isinstance(c, (Polygon, BSpline)):
                print "Points:", c.points
                c.points = [(t[0] - x1, t[1] - y1) for t in c.points]
                print "Points:", c.points

            elif isinstance(c, Text):
                #                font = str_to_font( str(c.pen.font) )
                c.text_x, c.text_y = c.x - x1, c.y - y1

        container = Container(auto_size=True,
                              position=[x1, y1],
                              bgcolor="yellow")

        container.add(*components)

        if name == "_draw_":
            self.drawing = container
        elif name == "_hdraw_":
            self.arrowhead_drawing = container
        else:
            raise

    @on_trait_change("drawing,arrowhead_drawing")
    def _on_drawing(self, object, name, old, new):
        """ Handles the containers of drawing components being set.
        """
        attrs = ["drawing", "arrowhead_drawing"]

        others = [getattr(self, a) for a in attrs \
            if (a != name) and (getattr(self, a) is not None)]

        x, y = self.component.position
        print "POS:", x, y, self.component.position

        abs_x = [d.x + x for d in others]
        abs_y = [d.y + y for d in others]

        print "ABS:", abs_x, abs_y

        # Assume that he new drawing is positioned relative to graph origin.
        x1 = min(abs_x + [new.x])
        y1 = min(abs_y + [new.y])

        print "DRAW:", new.position
        new.position = [new.x - x1, new.y - y1]
        print "DRAW:", new.position

        #        for i, b in enumerate( others ):
        #            self.drawing.position = [100, 100]
        #            self.drawing.request_redraw()
        #            print "OTHER:", b.position, abs_x[i] - x1
        #            b.position = [ abs_x[i] - x1, abs_y[i] - y1 ]
        #            b.x = 50
        #            b.y = 50
        #            print "OTHER:", b.position, abs_x[i], x1

        #        for attr in attrs:
        #            if attr != name:
        #                if getattr(self, attr) is not None:
        #                    drawing = getattr(self, attr)
        #                    drawing.position = [50, 50]

        if old is not None:
            self.component.remove(old)
        if new is not None:
            self.component.add(new)

        print "POS NEW:", self.component.position
        self.component.position = [x1, y1]
        print "POS NEW:", self.component.position
        self.component.request_redraw()
        print "POS NEW:", self.component.position
Exemple #14
0
class NodeTree(Tree):
    """ A tree control with extensible node types. """

    #### 'Tree' interface #####################################################

    # The model that provides the data for the tree.
    model = Instance(NodeTreeModel, ())

    #### 'NodeTree' interface #################################################

    # The node manager looks after all node types.
    node_manager = Property(Instance(NodeManager))

    # The node types in the tree.
    node_types = Property(List(NodeType))

    ###########################################################################
    # 'NodeTree' interface.
    ###########################################################################

    #### Properties ###########################################################

    # node_manager
    def _get_node_manager(self):
        """ Returns the node manager of the tree's model. """

        return self.model.node_manager

    def _set_node_manager(self, node_manager):
        """ Sets the node manager of the tree's model. """

        self.model.node_manager = node_manager

        return

    # node_types
    def _get_node_types(self):
        """ Returns the node types in the tree. """

        return self.model.node_manager.node_types

    def _set_node_types(self, node_types):
        """ Sets the node types in the tree. """

        self.model.node_manager.node_types = node_types

        return

    ###########################################################################
    # 'Tree' interface.
    ###########################################################################

    #### Trait event handlers #################################################

    def _node_activated_changed(self, obj):
        """ Called when a node has been activated (i.e., double-clicked).

        Performs the model's default action for 'obj', if it has one.
        """

        default_action = self.model.get_default_action(obj)
        if default_action is not None:
            self._perform_default_action(default_action, obj)

        return

    def _node_right_clicked_changed(self, event):
        """ Called when the right mouse button is clicked on the tree.

        'event' is an (obj, point) pair: the node that was clicked and the
        point at which the click happened.  The node is selected and, if the
        model provides one, its context menu is popped up at 'point'.
        """

        # Unpack here rather than in the signature: tuple parameters are
        # Python 2 only (removed by PEP 3113).
        obj, point = event

        # Add the node that the right-click occurred on to the selection.
        self.select(obj)

        # fixme: This is a hack to allow us to attach the node that the
        # right-click occurred on to the action event.
        self._context = obj

        # Ask the model for the node's context menu.
        menu_manager = self.model.get_context_menu(obj)
        if menu_manager is not None:
            self._popup_menu(menu_manager, obj, point)

        return
Exemple #15
0
class TypeManager(HasTraits):
    """ A type manager.

    A type manager turns an object into an instance of a requested target
    class.  It uses the object directly, asks it (as a factory) to create an
    instance, or adapts it.  Managers can be chained through 'parent'.

    """

    #### 'TypeManager' interface ##############################################

    # Read-only access to the manager that holds every registered adapter.
    adapter_manager = Property(Instance(AdapterManager))

    # Globally unique identifier of this type manager (only needed when more
    # than one type manager is in use).
    id = Str

    # The parent type manager, None by default.
    #
    # Setting it arranges type managers into a hierarchy: whenever this
    # manager cannot adapt or create an object of the requested target class,
    # 'object_as' asks the parent to try.
    parent = Instance('TypeManager')

    # The type system that decides 'is_a' relationships for this manager.
    #
    # Defaults to standard Python semantics.
    type_system = Instance(AbstractTypeSystem, PythonTypeSystem())

    #### Private interface ####################################################

    # The manager that holds every registered adapter.
    _adapter_manager = Instance(AdapterManager)

    ###########################################################################
    # 'TypeManager' interface.
    ###########################################################################

    #### Properties ###########################################################

    def _get_adapter_manager(self):
        """ Property getter for 'adapter_manager'. """
        return self._adapter_manager

    #### Methods ##############################################################

    def object_as(self, obj, target_class, *args, **kw):
        """ Returns 'obj' as an instance of 'target_class'.

        Tries, in order:
        - 'obj' itself, when it already is a 'target_class';
        - an instance created by 'obj', when it is a factory for the class;
        - an adapter obtained from the adapter manager.

        Any extra arguments are passed on.  If this yields None and a parent
        manager is set, the parent is asked with the same arguments.

        Returns None if no appropriate adapter or factory is available.

        """
        if self.type_system.is_a(obj, target_class):
            candidate = obj
        elif self._is_factory_for(obj, target_class, *args, **kw):
            candidate = obj.create(target_class, *args, **kw)
        else:
            candidate = self._adapter_manager.adapt(
                obj, target_class, *args, **kw)

        # Nothing found locally: give the parent manager (if any) a go.
        if candidate is None and self.parent is not None:
            candidate = self.parent.object_as(obj, target_class, *args, **kw)

        return candidate

    # Adapters.
    def register_adapters(self, factory, adaptee_class):
        """ Registers an adapter factory for 'adaptee_class'.

        'adaptee_class' may be a class or the name of a class.

        """
        self._adapter_manager.register_adapters(factory, adaptee_class)

    def unregister_adapters(self, factory):
        """ Unregisters the adapter factory 'factory'. """
        self._adapter_manager.unregister_adapters(factory)

    def register_instance_adapters(self, factory, obj):
        """ Registers an adapter factory for the single instance 'obj'.

        A factory can belong to exactly one manager, because it uses that
        manager's type system.

        """
        self._adapter_manager.register_instance_adapters(factory, obj)

    def unregister_instance_adapters(self, factory, obj):
        """ Unregisters an adapter factory for the single instance 'obj'.

        A factory can belong to exactly one manager, because it uses that
        manager's type system.

        """
        self._adapter_manager.unregister_instance_adapters(factory, obj)

    def register_type_adapters(self, factory, adaptee_class):
        """ Registers an adapter factory for the type 'adaptee_class'.

        'adaptee_class' may be a class or the name of a class.

        """
        self._adapter_manager.register_type_adapters(factory, adaptee_class)

    def unregister_type_adapters(self, factory):
        """ Unregisters the type adapter factory 'factory'. """
        self._adapter_manager.unregister_type_adapters(factory)

    # Categories.
    #
    # This is a pure convenience wrapper today; routing additions through
    # the type manager leaves room to track every category added later.
    def add_category(self, klass, category_class):
        """ Adds the category 'category_class' to the class 'klass'. """
        klass.add_trait_category(category_class)

    # Hooks.
    #
    # These are pure convenience wrappers around the module-level hook
    # functions of the same names; routing them through the type manager
    # leaves room to track every hook later.
    def add_pre(self, klass, method_name, callable):
        """ Adds a pre-hook to method 'method_name' on class 'klass'. """
        add_pre(klass, method_name, callable)

    def add_post(self, klass, method_name, callable):
        """ Adds a post-hook to method 'method_name' on class 'klass'. """
        add_post(klass, method_name, callable)

    def remove_pre(self, klass, method_name, callable):
        """ Removes a pre-hook from method 'method_name' on class 'klass'. """
        remove_pre(klass, method_name, callable)

    def remove_post(self, klass, method_name, callable):
        """ Removes a post-hook from method 'method_name' on class 'klass'. """
        remove_post(klass, method_name, callable)

    ###########################################################################
    # Private interface.
    ###########################################################################

    #### Initializers #########################################################

    def __adapter_manager_default(self):
        """ Trait initializer: default value of '_adapter_manager'. """
        return AdapterManager(type_system=self.type_system)

    #### Methods ##############################################################

    def _is_factory_for(self, obj, target_class, *args, **kw):
        """ Returns True iff 'obj' is a factory for 'target_class'. """
        # Mirrors 'a and b': only ask the factory when 'obj' is one.
        verdict = self.type_system.is_a(obj, Factory)
        if verdict:
            verdict = obj.can_create(target_class, *args, **kw)
        return verdict
Exemple #16
0
class EXDesignReader(HasTraits):
    '''Read an experiment design and the data files it refers to.

    The design is described in a semicolon-separated
    csv file providing the information about the
    design parameters.

    Each data file has the name n.txt
    '''

    #--------------------------------------------------------------------
    # Specification of the design - factor list, relative paths, etc
    #--------------------------------------------------------------------
    open_exdesign = Button()

    def _open_exdesign_fired(self):
        '''Let the user pick an '*.eds' spec file and load it.'''
        file_name = open_file(filter=['*.eds'],
                              extensions=[FileInfo(), TextInfo()])
        # An empty string means no file was chosen.
        if file_name != '':
            self.exdesign_spec_file = file_name

    exdesign_spec_file = File

    def _exdesign_spec_file_changed(self):
        '''Parse the spec file into a new ExDesignSpec instance.'''
        print('changed file')
        # 'with' guarantees the file is closed; 'open' works on Python 2 and 3
        # (the builtin 'file' exists only on Python 2).
        with open(self.exdesign_spec_file) as spec_file:
            spec_text = spec_file.read()
        # SECURITY: the file content is spliced into an 'ExDesignSpec( ... )'
        # call and evaluated as Python code. Only load trusted spec files.
        self.exdesign_spec = eval('ExDesignSpec( %s )' % spec_text)

    exdesign_spec = Instance(ExDesignSpec)

    def _exdesign_spec_default(self):
        return ExDesignSpec()

    @on_trait_change('exdesign_spec')
    def _reset_design_file(self):
        '''Resolve the spec's design file relative to the spec file's dir.'''
        spec_dir = os.path.dirname(self.exdesign_spec_file)
        exdesign_file = self.exdesign_spec.design_file
        self.design_file = os.path.join(spec_dir, exdesign_file)

    #--------------------------------------------------------------------
    # file containing the association between the factor combinations
    # and data files having the data
    #--------------------------------------------------------------------
    design_file = File

    def _design_file_changed(self):
        self.exdesign = self._read_exdesign()

    exdesign_table_columns = Property(List, depends_on='exdesign_spec+')

    @cached_property
    def _get_exdesign_table_columns(self):
        '''One read-only table column per design factor (named by ps[2]).'''
        return [ObjectColumn(name=ps[2],
                             editable=False,
                             width=0.15) for ps in self.exdesign_spec.factors]

    exdesign = List(Any)

    def _exdesign_default(self):
        return self._read_exdesign()

    def _read_exdesign(self):
        ''' Read the experiment design.

        Returns one ExRun per row of the semicolon-separated design file,
        or an empty list when the design file does not exist.
        '''
        if not exists(self.design_file):
            return []

        data_dir = os.path.join(os.path.dirname(self.design_file),
                                self.exdesign_spec.data_dir)

        # Build every run while the file is open; the 'with' block closes
        # the file even if constructing an ExRun raises.
        with open(self.design_file, 'r') as design:
            reader = csv.reader(design, delimiter=';')
            return [ExRun(self, row, data_dir=data_dir) for row in reader]

    selected_exrun = Instance(ExRun)

    def _selected_exrun_default(self):
        if len(self.exdesign) > 0:
            return self.exdesign[0]
        else:
            return None

    last_exrun = Instance(ExRun)

    selected_exruns = List(ExRun)

    #------------------------------------------------------------------
    # Array plotting
    #-------------------------------------------------------------------
    # List of arrays to be plotted
    data = Instance(AbstractPlotData)

    def _data_default(self):
        return ArrayPlotData(x=array([]), y=array([]))

    @on_trait_change('selected_exruns')
    def _rest_last_exrun(self):
        '''Track the most recently selected run.'''
        if len(self.selected_exruns) > 0:
            self.last_exrun = self.selected_exruns[-1]

    @on_trait_change('selected_exruns')
    def _reset_data(self):
        '''Rebuild the plot from the currently selected runs.

        Each run contributes its measured curve (brown) and its polynomial
        fit (blue). A data source is registered only if its label is not
        already present in the plot's datasources.
        '''
        _, xlabels, ylabels, ylabels_fitted = self._generate_data_labels()
        for name in list(self.plot.plots.keys()):
            self.plot.delplot(name)

        for exrun, xlabel, ylabel, ylabel_fitted in zip(
                self.selected_exruns, xlabels, ylabels, ylabels_fitted):
            if xlabel not in self.plot.datasources:
                self.plot.datasources[xlabel] = ArrayDataSource(
                    exrun.xdata, sort_order='none')
            if ylabel not in self.plot.datasources:
                self.plot.datasources[ylabel] = ArrayDataSource(
                    exrun.ydata, sort_order='none')
            if ylabel_fitted not in self.plot.datasources:
                self.plot.datasources[ylabel_fitted] = ArrayDataSource(
                    exrun.polyfit, sort_order='none')

        for xlabel, ylabel, ylabel_fitted in zip(xlabels, ylabels,
                                                 ylabels_fitted):
            self.plot.plot((xlabel, ylabel), color='brown')
            self.plot.plot((xlabel, ylabel_fitted), color='blue')

    def _generate_data_labels(self):
        ''' Generate the labels consisting of the axis and run-number.

        Returns (run numbers, x labels, y labels, fitted-y labels), each a
        list parallel to 'selected_exruns'.
        '''
        return ([e.std_num for e in self.selected_exruns],
                ['x-%d' % e.std_num for e in self.selected_exruns],
                ['y-%d' % e.std_num for e in self.selected_exruns],
                ['y-%d-fitted' % e.std_num for e in self.selected_exruns])

    plot = Instance(Plot)

    def _plot_default(self):
        '''A Chaco plot with pan and zoom tools attached.'''
        p = Plot()
        p.tools.append(PanTool(p))
        p.overlays.append(ZoomTool(p))
        return p

    view_traits = View(HSplit(VGroup(Item('open_exdesign',
                                          style='simple'),
                                     Item('exdesign',
                                          editor=exrun_table_editor,
                                          show_label=False, style='custom')
                                     ),
                              VGroup(Item('last_exrun@',
                                          show_label=False),
                                     Item('plot',
                                          editor=ComponentEditor(),
                                          show_label=False,
                                          resizable=True
                                          ),
                                     ),
                              ),
                       #                        handler = EXDesignReaderHandler(),
                       resizable=True,
                       buttons=[OKButton, CancelButton],
                       height=1.,
                       width=1.)
Exemple #17
0
class Volume(Module):
    """The Volume module visualizes scalar fields using volumetric
    visualization techniques.  This supports ImageData and
    UnstructuredGrid data.  It also supports the FixedPointRenderer
    for ImageData.  However, the performance is slow so your best bet
    is probably with the ImageData based renderers.
    """

    # The version of this class.  Used for persistence.
    __version__ = 0

    # The volume mapper in use; the choices come from '_mapper_types'.
    volume_mapper_type = DEnum(values_name='_mapper_types',
                               desc='volume mapper to use')

    # The ray cast function in use; the choices come from
    # '_ray_cast_functions'.
    ray_cast_function_type = DEnum(values_name='_ray_cast_functions',
                                   desc='Ray cast function to use')

    # The tvtk.Volume actor (created in 'setup_pipeline').
    volume = ReadOnly

    # Getter-only properties exposing the private mapper, volume property
    # and ray cast function.
    volume_mapper = Property(record=True)

    volume_property = Property(record=True)

    ray_cast_function = Property(record=True)

    # The lookup-table manager shown on the 'Legend' tab.
    lut_manager = Instance(VolumeLUTManager, args=(), allow_none=False,
                           record=True)

    input_info = PipelineInfo(datasets=['image_data',
                                        'unstructured_grid'],
                              attribute_types=['any'],
                              attributes=['scalars'])

    ########################################
    # View related code.

    update_ctf = Button('Update CTF')

    view = View(Group(Item(name='_volume_property', style='custom',
                           editor=CustomEditor(gradient_editor_factory),
                           resizable=True),
                      Item(name='update_ctf'),
                      label='CTF',
                      show_labels=False),
                Group(Item(name='volume_mapper_type'),
                      Group(Item(name='_volume_mapper',
                                 style='custom',
                                 resizable=True),
                            show_labels=False
                            ),
                      Item(name='ray_cast_function_type'),
                      Group(Item(name='_ray_cast_function',
                                 enabled_when='len(_ray_cast_functions) > 0',
                                 style='custom',
                                 resizable=True),
                            show_labels=False),
                      label='Mapper',
                      ),
                Group(Item(name='_volume_property', style='custom',
                           resizable=True),
                      label='Property',
                      show_labels=False),
                Group(Item(name='volume', style='custom',
                           editor=InstanceEditor(),
                           resizable=True),
                      label='Volume',
                      show_labels=False),
                Group(Item(name='lut_manager', style='custom',
                           resizable=True),
                      label='Legend',
                      show_labels=False),
                resizable=True
                )

    ########################################
    # Private traits
    _volume_mapper = Instance(tvtk.AbstractVolumeMapper)
    _volume_property = Instance(tvtk.VolumeProperty)
    _ray_cast_function = Instance(tvtk.Object)

    _mapper_types = List(Str, ['TextureMapper2D', 'RayCastMapper', ])

    # Mapper class names reported by 'find_volume_mappers' (plus
    # 'VolumeProMapper' when available).
    _available_mapper_types = List(Str)

    _ray_cast_functions = List(Str)

    # The scalar range the transfer functions are currently scaled to.
    current_range = Tuple

    # The color transfer function.
    _ctf = Instance(ColorTransferFunction)
    # The opacity values.
    _otf = Instance(PiecewiseFunction)

    ######################################################################
    # `object` interface
    ######################################################################
    def __get_pure_state__(self):
        """Return the picklable state, with the CTFs saved as 'ctf_state'."""
        d = super(Volume, self).__get_pure_state__()
        d['ctf_state'] = save_ctfs(self._volume_property)
        # These are rebuilt from the data / 'ctf_state' when restoring.
        for name in ('current_range', '_ctf', '_otf'):
            d.pop(name, None)
        return d

    def __set_pure_state__(self, state):
        """Restore state saved by '__get_pure_state__'."""
        # The mapper type is set first so the matching mapper exists before
        # the remaining state is applied.
        self.volume_mapper_type = state['_volume_mapper_type']
        state_pickler.set_state(self, state, ignore=['ctf_state'])
        ctf_state = state['ctf_state']
        ctf, otf = load_ctfs(ctf_state, self._volume_property)
        self._ctf = ctf
        self._otf = otf
        self._update_ctf_fired()

    ######################################################################
    # `Module` interface
    ######################################################################
    def start(self):
        super(Volume, self).start()
        self.lut_manager.start()

    def stop(self):
        super(Volume, self).stop()
        self.lut_manager.stop()

    def setup_pipeline(self):
        """Override this method so that it *creates* the tvtk
        pipeline.
        """
        v = self.volume = tvtk.Volume()
        vp = self._volume_property = tvtk.VolumeProperty()

        # Default transfer functions over 0..255; they are rescaled to the
        # data range via '_current_range_changed'.
        self._ctf = ctf = default_CTF(0, 255)
        self._otf = otf = default_OTF(0, 255)
        vp.set_color(ctf)
        vp.set_scalar_opacity(otf)
        vp.shade = True
        vp.interpolation_type = 'linear'
        v.property = vp

        # Re-render whenever the actor or its property changes.
        v.on_trait_change(self.render)
        vp.on_trait_change(self.render)

        available_mappers = find_volume_mappers()
        if is_volume_pro_available():
            self._mapper_types.append('VolumeProMapper')
            available_mappers.append('VolumeProMapper')

        self._available_mapper_types = available_mappers
        if 'FixedPointVolumeRayCastMapper' in available_mappers:
            self._mapper_types.append('FixedPointVolumeRayCastMapper')

        self.actors.append(v)

    def update_pipeline(self):
        """Override this method so that it *updates* the tvtk pipeline
        when data upstream is known to have changed.

        This method is invoked (automatically) when any of the inputs
        sends a `pipeline_changed` event.
        """
        mm = self.module_manager
        if mm is None:
            return

        input = mm.source.outputs[0]

        ug = hasattr(tvtk, 'UnstructuredGridVolumeMapper')
        if ug:
            if not input.is_a('vtkImageData') \
                   and not input.is_a('vtkUnstructuredGrid'):
                error('Volume rendering only works with '\
                      'StructuredPoints/ImageData/UnstructuredGrid datasets')
                return
        elif not input.is_a('vtkImageData'):
            error('Volume rendering only works with '\
                  'StructuredPoints/ImageData datasets')
            return

        self._setup_mapper_types()
        self._setup_current_range()
        self._volume_mapper_type_changed(self.volume_mapper_type)
        self._update_ctf_fired()
        self.pipeline_changed = True

    def update_data(self):
        """Override this method so that it flushes the vtk pipeline if
        that is necessary.

        This method is invoked (automatically) when any of the inputs
        sends a `data_changed` event.
        """
        self._setup_mapper_types()
        self._setup_current_range()
        self._update_ctf_fired()
        self.data_changed = True

    ######################################################################
    # Non-public methods.
    ######################################################################
    def _setup_mapper_types(self):
        """Sets up the mapper based on input data types.
        """
        input = self.module_manager.source.outputs[0]
        if input.is_a('vtkUnstructuredGrid'):
            if hasattr(tvtk, 'UnstructuredGridVolumeMapper'):
                check = ['UnstructuredGridVolumeZSweepMapper',
                         'UnstructuredGridVolumeRayCastMapper',
                         ]
                mapper_types = []
                for mapper in check:
                    if mapper in self._available_mapper_types:
                        mapper_types.append(mapper)
                if len(mapper_types) == 0:
                    # No usable mapper: '' is handled as a no-op by
                    # '_volume_mapper_type_changed'.
                    mapper_types = ['']
                self._mapper_types = mapper_types
                return
        else:
            scalars = input.point_data.scalars
            if scalars is None:
                # Nothing to base the choice on; '_setup_current_range'
                # reports the missing scalars, so keep the current types
                # instead of crashing on 'None.data_type'.
                return
            if scalars.data_type not in \
               [vtkConstants.VTK_UNSIGNED_CHAR,
                vtkConstants.VTK_UNSIGNED_SHORT]:
                # Only the fixed-point mapper handles other scalar types.
                if 'FixedPointVolumeRayCastMapper' \
                       in self._available_mapper_types:
                    self._mapper_types = ['FixedPointVolumeRayCastMapper']
                else:
                    # Implicit concatenation: a backslash continuation
                    # inside the literal would embed the indentation.
                    error('Available volume mappers only work with '
                          'unsigned_char or unsigned_short datatypes')
            else:
                mapper_types = ['TextureMapper2D', 'RayCastMapper']
                check = ['FixedPointVolumeRayCastMapper',
                         'VolumeProMapper'
                         ]
                for mapper in check:
                    if mapper in self._available_mapper_types:
                        mapper_types.append(mapper)
                self._mapper_types = mapper_types

    def _setup_current_range(self):
        """Sync the LUT defaults and 'current_range' with the input data."""
        mm = self.module_manager
        # Set the default name and range for our lut.
        lm = self.lut_manager
        slm = mm.scalar_lut_manager
        lm.set(default_data_name=slm.default_data_name,
               default_data_range=slm.default_data_range)

        # Set the current range.
        input = mm.source.outputs[0]
        sc = input.point_data.scalars
        if sc is not None:
            rng = sc.range
        else:
            error('No scalars in input data!')
            rng = (0, 255)

        # Only assign on change so '_current_range_changed' (rescale and
        # render) does not run needlessly.
        if self.current_range != rng:
            self.current_range = rng

    def _get_volume_mapper(self):
        return self._volume_mapper

    def _get_volume_property(self):
        return self._volume_property

    def _get_ray_cast_function(self):
        return self._ray_cast_function

    def _volume_mapper_type_changed(self, value):
        """Create the mapper named by 'value' and attach it to the volume.

        Unknown or empty names (e.g. '' when no mapper supports the input)
        leave the current mapper untouched.
        """
        mm = self.module_manager
        if mm is None:
            return

        # Build the new mapper first, so an unrecognised name can bail out
        # before the current mapper is disturbed (previously 'new_vm' was
        # left unbound and raised after the render hook had been removed).
        if value == 'RayCastMapper':
            new_vm = tvtk.VolumeRayCastMapper()
            ray_cast_functions = ['RayCastCompositeFunction',
                                  'RayCastMIPFunction',
                                  'RayCastIsosurfaceFunction']
        elif value == 'TextureMapper2D':
            new_vm = tvtk.VolumeTextureMapper2D()
            ray_cast_functions = ['']
        elif value == 'VolumeProMapper':
            new_vm = tvtk.VolumeProMapper()
            ray_cast_functions = ['']
        elif value == 'FixedPointVolumeRayCastMapper':
            new_vm = tvtk.FixedPointVolumeRayCastMapper()
            ray_cast_functions = ['']
        elif value == 'UnstructuredGridVolumeRayCastMapper':
            new_vm = tvtk.UnstructuredGridVolumeRayCastMapper()
            ray_cast_functions = ['']
        elif value == 'UnstructuredGridVolumeZSweepMapper':
            new_vm = tvtk.UnstructuredGridVolumeZSweepMapper()
            ray_cast_functions = ['']
        else:
            return

        old_vm = self._volume_mapper
        if old_vm is not None:
            old_vm.on_trait_change(self.render, remove=True)

        # '_volume_mapper' must be set before '_ray_cast_functions': changing
        # the choices may fire '_ray_cast_function_type_changed', which
        # installs the function on 'self._volume_mapper'.
        self._volume_mapper = new_vm
        self._ray_cast_functions = ray_cast_functions
        if value == 'RayCastMapper':
            new_vm.volume_ray_cast_function = \
                tvtk.VolumeRayCastCompositeFunction()

        new_vm.input = mm.source.outputs[0]
        self.volume.mapper = new_vm
        new_vm.on_trait_change(self.render)

    def _update_ctf_fired(self):
        set_lut(self.lut_manager.lut, self._volume_property)
        self.render()

    def _current_range_changed(self, old, new):
        rescale_ctfs(self._volume_property, new)
        self.render()

    def _ray_cast_function_type_changed(self, old, new):
        """Swap in the tvtk 'Volume<new>' ray cast function ('' = none)."""
        rcf = self.ray_cast_function
        # Truthiness also covers a None 'old', on which len() would fail.
        if old and rcf is not None:
            rcf.on_trait_change(self.render, remove=True)

        if new:
            new_rcf = getattr(tvtk, 'Volume%s' % new)()
            new_rcf.on_trait_change(self.render)
            self._volume_mapper.volume_ray_cast_function = new_rcf
            self._ray_cast_function = new_rcf
        else:
            self._ray_cast_function = None

        self.render()

    def _scene_changed(self, old, new):
        super(Volume, self)._scene_changed(old, new)
        self.lut_manager.scene = new
Exemple #18
0
class ExRun(HasTraits):
    '''
    Represent a single test: its design parameters (factor levels)
    and access to the measured data.
    '''
    data_dir = Str
    exdesign_reader = WeakRef

    def __init__(self, exdesign_reader, row, **kw):
        '''Retrieve the traits from the exdesign reader.

        One trait is added per factor spec 'ps' of the reader's design spec:
        ps[0] names the trait factory, ps[1] a converter applied to the raw
        string row[idx], and ps[2] the trait name (presumably something like
        ('Float', 'float', 'width') -- confirm against ExDesignSpec).
        '''
        self.exdesign_reader = exdesign_reader
        factors = self.exdesign_reader.exdesign_spec.factors
        for idx, ps in enumerate(factors):
            # SECURITY: row values come from the design CSV and are spliced
            # into code passed to eval(); a value containing '"' can inject
            # arbitrary code. Only use trusted design files.
            cmd = '%s( %s("%s") )' % (ps[0], ps[1], row[idx])
            self.add_trait(ps[2], eval(cmd))
        super(ExRun, self).__init__(**kw)

    data_file = File

    def _data_file_default(self):
        return os.path.join(self.data_dir, self._get_file_name())

    @on_trait_change('data_file')
    def _reset_data_file(self):
        # Any assignment to 'data_file' is replaced by the path derived from
        # 'data_dir' and the spec's file-name expression. (This presumably
        # terminates because Traits does not re-notify for an unchanged
        # value.)
        self.data_file = os.path.join(self.data_dir, self._get_file_name())

    def _get_file_name(self):
        '''Evaluate the spec's 'data_file_name' expression for this run.'''
        # NOTE: an earlier duplicate definition returning
        # str(self.std_num) + '.txt' was dead code (shadowed by this one)
        # and has been removed.
        # SECURITY: 'data_file_name' is evaluated as Python code with
        # 'self' in scope; only use trusted spec files.
        fname = eval(self.exdesign_reader.exdesign_spec.data_file_name)
        print('fname', fname)
        return fname

    _arr = Property(Array(float), depends_on='data_file')

    @cached_property
    def _get__arr(self):
        '''Load the data file (two header rows skipped).'''
        # Cached so that 'xdata' and 'ydata' share one read per data_file.
        return loadtxt(self.data_file, skiprows=2,
                       delimiter=self.exdesign_reader.exdesign_spec.data_delimiter,
                       converters=self.exdesign_reader.exdesign_spec.data_converters)

    xdata = Property(Array(float), depends_on='data_file')

    @cached_property
    def _get_xdata(self):
        return self._arr[:, 0]

    ydata = Property(Array(float), depends_on='data_file')

    @cached_property
    def _get_ydata(self):
        return self._arr[:, 1]

    max_stress_idx = Property(Int)

    def _get_max_stress_idx(self):
        return argmax(self.ydata)

    max_stress = Property(Float)

    def _get_max_stress(self):
        return self.ydata[self.max_stress_idx]

    strain_at_max_stress = Property(Float)

    def _get_strain_at_max_stress(self):
        return self.xdata[self.max_stress_idx]

    # get the ascending branch of the response curve
    xdata_asc = Property(Array(float))

    def _get_xdata_asc(self):
        return self.xdata[:self.max_stress_idx + 1]

    ydata_asc = Property(Array(float))

    def _get_ydata_asc(self):
        return self.ydata[:self.max_stress_idx + 1]

    # interpolate the polynomial
    polyfit = Property(Array(float))

    def _get_polyfit(self):
        '''Fit a degree-5 polynomial to the ascending branch and return it
        evaluated at the ascending-branch x values (used for visualization).
        '''
        # 'polyfit' here is the global (numpy) function, not the trait.
        coeffs = polyfit(self.xdata_asc, self.ydata_asc, 5)
        pf = poly1d(coeffs)
        # poly1d evaluates whole arrays directly; no frompyfunc needed.
        return array(pf(self.xdata_asc), dtype=float)

    traits_view = View(Item('data_dir', style='readonly'),
                       Item('max_stress_idx', style='readonly'),
                       Item('max_stress', style='readonly'),
                       Item('strain_at_max_stress', style='readonly'),
                       )
Exemple #19
0
class AdapterManager(HasTraits):
    """ A manager for adapter factories. """

    #### 'AdapterManager' interface ###########################################

    # All registered type-scope factories by the type of object that they
    # adapt.
    #
    # The dictionary is keyed by the *name* of the class rather than the class
    # itself to allow for adapter factory proxies to register themselves
    # without having to load and create the factories themselves (i.e., to
    # allow us to lazily load adapter factories contributed by plugins). This
    # is a slight compromise as it is obviously geared towards use in Envisage,
    # but it doesn't affect the API other than allowing a class OR a string to
    # be passed to 'register_adapters'.
    #
    # { String adaptee_class_name : List(AdapterFactory) factories }
    type_factories = Property(Dict)

    # All registered instance-scope factories by the object that they adapt.
    #
    # { id(obj) : List(AdapterFactory) factories }
    instance_factories = Property(Dict)

    # The type system used by the manager (it determines 'is_a' relationships
    # and type MROs etc). By default we use standard Python semantics.
    type_system = Instance(AbstractTypeSystem, PythonTypeSystem())

    #### Private interface ####################################################

    # All registered type-scope factories by the type of object that they
    # adapt.
    _type_factories = Dict

    # All registered instance-scope factories by the object that they adapt.
    _instance_factories = Dict

    ###########################################################################
    # 'AdapterManager' interface.
    ###########################################################################

    #### Properties ###########################################################

    def _get_type_factories(self):
        """ Returns a (shallow) copy of the type-scope factory mapping. """

        return self._type_factories.copy()

    def _get_instance_factories(self):
        """ Returns a (shallow) copy of the instance-scope factory mapping. """

        return self._instance_factories.copy()

    #### Methods ##############################################################

    def adapt(self, adaptee, target_class, *args, **kw):
        """ Returns an adapter that adapts an object to the target class.

        'adaptee' is the object that we want to adapt.
        'target_class' is the class that the adaptee should be adapted to.

        Returns None if no such adapter exists.

        """

        # If the object is already an instance of the target class then we
        # simply return it.
        if self.type_system.is_a(adaptee, target_class):
            adapter = adaptee

        # Otherwise, look at each class in the adaptee's MRO to see if there
        # is an adapter factory registered that can adapt the object to the
        # target class.
        else:
            # Look for instance-scope adapters first.
            adapter = self._adapt_instance(adaptee, target_class, *args, **kw)

            # If no instance-scope adapter was found then try type-scope
            # adapters.
            if adapter is None:
                for adaptee_class in self.type_system.get_mro(type(adaptee)):
                    adapter = self._adapt_type(
                        adaptee, adaptee_class, target_class, *args, **kw
                    )
                    if adapter is not None:
                        break

        return adapter

    def register_instance_adapters(self, factory, obj):
        """ Registers an instance-scope adapter factory.

        A factory can be in exactly one manager (as it uses the manager's type
        system).

        """

        # Keyed by id(obj), so the manager does not hold a reference to obj.
        factories = self._instance_factories.setdefault(id(obj), [])
        factories.append(factory)

        # A factory can be in exactly one manager.
        factory.adapter_manager = self

        return

    def unregister_instance_adapters(self, factory, obj):
        """ Unregisters an instance scope adapter factory.

        A factory can be in exactly one manager (as it uses the manager's type
        system).

        """

        # Use 'get' rather than 'setdefault' so unregistering for an object
        # that was never registered does not insert an empty entry.
        factories = self._instance_factories.get(id(obj), [])
        if factory in factories:
            factories.remove(factory)

        # A factory can be in exactly one manager.
        factory.adapter_manager = None

        return

    def register_type_adapters(self, factory, adaptee_class):
        """ Registers a type-scope adapter factory.

        'adaptee_class' can be either a class object or the name of a class.

        A factory can be in exactly one manager (as it uses the manager's type
        system).

        """

        try:
            string_types = basestring  # Python 2: str and unicode.
        except NameError:
            string_types = str  # Python 3 has no 'basestring'.

        if isinstance(adaptee_class, string_types):
            adaptee_class_name = adaptee_class

        else:
            adaptee_class_name = self._get_class_name(adaptee_class)

        factories = self._type_factories.setdefault(adaptee_class_name, [])
        factories.append(factory)

        # A factory can be in exactly one manager.
        factory.adapter_manager = self

        return

    def unregister_type_adapters(self, factory):
        """ Unregisters a type-scope adapter factory.

        Removes (the first occurrence of) 'factory' from every class's list.

        """

        for factories in self._type_factories.values():
            if factory in factories:
                factories.remove(factory)

        # The factory is no longer deemed to be part of this manager.
        factory.adapter_manager = None

        return

    #### DEPRECATED ###########################################################

    def register_adapters(self, factory, adaptee_class):
        """ Registers an adapter factory.

        DEPRECATED: use 'register_type_adapters' instead.

        'adaptee_class' can be either a class object or the name of a class.

        A factory can be in exactly one manager (as it uses the manager's type
        system).

        """

        # Single-argument print() behaves the same on Python 2 and 3.
        print('DEPRECATED: use "register_type_adapters" instead.')

        self.register_type_adapters(factory, adaptee_class)

        return

    def unregister_adapters(self, factory):
        """ Unregisters an adapter factory.

        DEPRECATED: use 'unregister_type_adapters' instead.

        """

        print('DEPRECATED: use "unregister_type_adapters" instead.')

        self.unregister_type_adapters(factory)

        return

    ###########################################################################
    # Private interface.
    ###########################################################################

    def _adapt_instance(self, adaptee, target_class, *args, **kw):
        """ Returns an adapter that adapts an object to the target class.

        Only instance-scope factories registered for 'adaptee' are tried, in
        registration order. Returns None if no such adapter exists.

        """

        for factory in self._instance_factories.get(id(adaptee), []):
            adapter = factory.adapt(adaptee, target_class, *args, **kw)
            if adapter is not None:
                break

        else:
            adapter = None

        return adapter

    def _adapt_type(self, adaptee, adaptee_class, target_class, *args, **kw):
        """ Returns an adapter that adapts an object to the target class.

        Only type-scope factories registered for 'adaptee_class' are tried,
        in registration order. Returns None if no such adapter exists.

        """

        adaptee_class_name = self._get_class_name(adaptee_class)

        for factory in self._type_factories.get(adaptee_class_name, []):
            adapter = factory.adapt(adaptee, target_class, *args, **kw)
            if adapter is not None:
                break

        else:
            adapter = None

        return adapter

    def _get_class_name(self, klass):
        """ Returns the full class name ('module.ClassName') for a class. """

        return "%s.%s" % (klass.__module__, klass.__name__)
Exemple #20
0
class TreeItem(HasTraits):
    """ A generic base-class for items in a tree data structure. """

    #### 'TreeItem' interface #################################################

    # Does this item allow children? (Not enforced by 'insert'/'append'.)
    allows_children = Bool(True)

    # The item's children.
    children = List(Instance('TreeItem'))

    # Arbitrary data associated with the item.
    data = Any

    # Does the item have any children?
    has_children = Property(Bool)

    # The item's parent.
    parent = Instance('TreeItem')

    ###########################################################################
    # 'object' interface.
    ###########################################################################

    def __str__(self):
        """ Returns str(data), or '' when the item has no data. """

        if self.data is None:
            s = ''

        else:
            s = str(self.data)

        return s

    ###########################################################################
    # 'TreeItem' interface.
    ###########################################################################

    #### Properties ###########################################################

    # has_children
    def _get_has_children(self):
        """ True iff the item has children. """

        return len(self.children) != 0

    #### Methods ##############################################################

    def append(self, child):
        """ Appends a child to this item and returns the child.

        This removes the child from its current parent (if it has one).

        """

        return self.insert(len(self.children), child)

    def insert(self, index, child):
        """ Inserts a child into this item at the specified index.

        This removes the child from its current parent (if it has one).
        'index' refers to 'children' as it is after that removal.
        Returns the child.

        """

        if child.parent is not None:
            child.parent.remove(child)

        child.parent = self
        self.children.insert(index, child)

        return child

    def remove(self, child):
        """ Removes a child from this item and returns it.

        Raises ValueError if 'child' is not one of this item's children.

        """

        # Remove from the list first so that a failed removal does not
        # leave the child's 'parent' cleared.
        self.children.remove(child)
        child.parent = None

        return child

    def insert_before(self, before, child):
        """ Inserts a child into this item before the specified item.

        This removes the child from its current parent (if it has one).
        Returns (insertion index, child).

        """

        index = self.children.index(before)

        # If the child already sits ahead of 'before' in this item, 'insert'
        # detaches it first, shifting 'before' one slot to the left.
        if child.parent is self and self.children.index(child) < index:
            index -= 1

        self.insert(index, child)

        return (index, child)

    def insert_after(self, after, child):
        """ Inserts a child into this item after the specified item.

        This removes the child from its current parent (if it has one).
        Returns (index 'after' had before the insertion, child).

        """

        index = self.children.index(after)
        position = index + 1

        # If the child already sits at or before 'after' in this item,
        # detaching it shifts 'after' one slot to the left; compensate so
        # the child really ends up right after 'after'.
        if child.parent is self and self.children.index(child) < position:
            position -= 1

        self.insert(position, child)

        return (index, child)
Exemple #21
0
class Base(TreeNodeObject):
    """Base class for objects that appear in the mayavi pipeline tree.

    Provides the common traits (scene, name, visibility, parent, menu
    helper, recorder), persistence via ``state_pickler``, optional loading
    of a hand-crafted view from a separate file, and the
    ``TreeNodeObject`` hooks used by the tree editor.
    """
    # The version of this class.  Used for persistence.
    __version__ = 0

    ########################################
    # Traits

    # The scene (RenderWindow) associated with this component.
    scene = Instance(TVTKScene, record=False)

    # Is this object running as part of the mayavi pipeline.
    running = Property(Bool, record=False)

    # The object's name.
    name = Str('')

    # The default icon.
    icon = 'module.ico'

    # The human readable type for this object
    type = Str('', record=False)

    # Is this object visible or not.
    visible = Bool(True, desc='if the object is visible')

    # Extend the children list with an AdderNode when a TreeEditor needs it.
    children_ui_list = Property(depends_on=['children'], record=False)

    # The parent of this object, i.e. self is an element of the parents
    # children.  If there is no notion of a parent/child relationship
    # this trait is None.
    parent = WeakRef(record=False)

    # A helper for the right click menus, context sensitivity etc.
    menu_helper = Instance(HasTraits, record=False)

    # Our recorder.
    recorder = Instance(Recorder, record=False)

    ##################################################
    # Private traits

    # Backing store for the 'running' property.
    _is_running = Bool(False)

    # This is used to save the state of the object when it is not
    # running.  When the object "starts", the state is loaded.  This
    # is done because a stopped object will not have a meaningful VTK
    # pipeline setup, so setting its state will lead to all kinds of
    # errors.
    _saved_state = Str('')

    # Hide and show actions
    _HideShowAction = Instance(Action,
                               kw={'name': 'Hide/Show',
                                   'action': 'object._hideshow'}, )

    # The menu shown on right-click for this.
    _menu = Instance(Menu, transient=True)

    # Path to the icon for this object.
    _icon_path = Str()

    # Adder node: a dialog to add children to this object
    _adder_node_class = None

    # Name of the file that may host the hand-crafted view
    _view_filename = Str(transient=True)

    # Hand crafted view.
    _module_view = Instance(View, transient=True)

    # Work around problem with HasPrivateTraits.
    __ = Python
    ##################################################

    ######################################################################
    # `object` interface
    ######################################################################
    def __get_pure_state__(self):
        """Method used by the state_pickler.

        Returns a copy of the instance dict without the runtime-only
        attributes (scene, UI helpers, parent links, views, ...).
        """
        d = self.__dict__.copy()
        for attr in ('scene', '_is_running', '__sync_trait__',
                     '__traits_listener__', '_icon_path',
                     '_menu', '_HideShowAction', 'menu_helper',
                     'parent', 'parent_', '_module_view',
                     '_view_filename', 'mlab_source'):
            d.pop(attr, None)
        return d

    def __getstate__(self):
        """Allows standard pickle to work via the state_pickler.
        """
        return state_pickler.dumps(self)

    def __setstate__(self, str_state):
        """Allows standard pickle to work via the state_pickler.

        The decoded state is stored in ``_saved_state`` and only applied
        immediately if the object is already running.
        """
        self.__init__()
        # Get the state from the string and update it.
        state = state_pickler.loads_state(str_state)
        state_pickler.update_state(state)
        # Save the state and load it if we are running.
        self._saved_state = cPickle.dumps(state)
        if self.running:
            self._load_saved_state()

    def __deepcopy__(self, memo):
        """Method used by copy.deepcopy().  This also uses the
        state_pickler to work correctly.
        """
        # Create a new instance.
        new = self.__class__()
        # If we have a saved state, use it for the new instance.  If
        # not, get our state and save that.
        saved_state = self._saved_state
        if len(saved_state) == 0:
            state = state_pickler.get_state(self)
            saved_state = cPickle.dumps(state)
        new._saved_state = saved_state
        # In the unlikely case that a new instance is running, load
        # the saved state.
        if new.running:
            new._load_saved_state()
        return new

    ######################################################################
    # `Base` interface
    ######################################################################
    def start(self):
        """Invoked when this object is added to the mayavi pipeline.

        Marks the object as running and applies any saved state.
        """
        self.running = True
        self._load_saved_state()

    def stop(self):
        """Invoked when this object is removed from the mayavi
        pipeline.
        """
        self.running = False

    def add_child(self, child):
        """This method intelligently adds a child to this object in
        the MayaVi pipeline.

        Must be overridden by subclasses.
        """
        raise NotImplementedError()

    def remove_child(self, child):
        """Remove specified child from our children.

        Must be overridden by subclasses.
        """
        raise NotImplementedError()

    def remove(self):
        """Remove ourselves from the mayavi pipeline.

        If we were the engine's current object, the parent becomes the
        current object.
        """
        if self.parent is not None:
            e = get_engine(self)
            self.parent.remove_child(self)
            if e.current_object is self:
                e.current_object = self.parent

    def render(self):
        """Invokes render on the scene, this in turn invokes Render on
        the VTK pipeline.
        """
        s = self.scene
        if s is not None:
            s.render()

    def dialog_view(self):
        """ Returns a view with an icon and a title.
        """
        view = self.trait_view()
        icon = self._icon_path + os.sep + 'images' + os.sep \
                            + self.icon
        view.icon = ImageResource(icon)
        view.title = "Edit%s: %s" % (self.type, self.name)
        view.buttons = ['OK', 'Cancel']
        return view

    def trait_view(self, name = None, view_element = None ):
        """ Gets or sets a ViewElement associated with an object's class.

        Overridden here to search for a separate file in the same directory
        for the view to use for this object. The view should be declared in
        the file named <class name>_view. If a file with this name is not
        found, the trait_view method on the base class will be called.
        """

        # If a name is specified, then call the HasTraits trait_view method
        # which will return (or assign) the *view_element* associated with
        # *name*.
        if name:
            return super(Base, self).trait_view(name, view_element)

        view = self._load_view_cached(name, view_element)
        # Uncomment this when developing views.
        #view = self._load_view_non_cached(name, view_element)
        return view

    ######################################################################
    # `TreeNodeObject` interface
    ######################################################################
    def tno_get_label(self, node):
        """Gets the label to display for a specified object.

        Side effect: an empty name is replaced by the class name.
        """
        if self.name == '':
            self.name = self.__class__.__name__
        return self.name

    def tno_get_view(self, node):
        """Gets the View to use when editing an object.
        """
        view = self.trait_view()
        view.kind = 'subpanel'
        return view

    def tno_confirm_delete(self, node):
        """Confirms that a specified object can be deleted or not.

        Returns None when the ``confirm_delete`` preference is set and
        True otherwise.
        """
        if preference_manager.root.confirm_delete:
            return None
        else:
            return True

    def tno_get_menu ( self, node ):
        """ Returns the contextual pop-up menu.
        """
        if self._menu is None:
            return super(Base, self).tno_get_menu(node)
        return self._menu

    def tno_get_icon(self, node, is_expanded):
        return self.icon

    def tno_get_icon_path(self, node):
        return self._icon_path

    def tno_delete_child(self, node, index):
        # When children_ui_list is longer than children, the tree index is
        # offset by one relative to self.children.
        if len(self.children_ui_list) > len(self.children):
            del self.children[index - 1]
        else:
            del self.children[index]

    def tno_append_child(self, node, child):
        """ Appends a child to the object's children.
        """
        self.children.append(child)

    def tno_insert_child(self, node, index, child):
        """ Inserts a child into the object's children.
        """
        # Same index offset as in tno_delete_child.
        if len(self.children_ui_list) > len(self.children):
            idx = index -1
        else:
            idx = index
        self.children[idx:idx] = [child]

    ######################################################################
    # Non-public interface
    ######################################################################
    def _get_running(self):
        return self._is_running

    def _set_running(self, new):
        # Only fire a 'running' change notification on an actual change.
        if self._is_running == new:
            return
        else:
            old = self._is_running
            self._is_running = new
            self.trait_property_changed('running', old, new)

    def _get_children_ui_list(self):
        """ Getter for Traits Property children_ui_list.

        For the base class, do not add anything to the children list.
        """
        if ((not preference_manager.root.show_helper_nodes or
                        len(self.children) > 0)
                or self._adder_node_class is None
                or (not self.type == ' scene' and
                    'none' in self.output_info.datasets)
                    # We can't use isinstance, as we would have circular
                    # imports
                ):
            return self.children
        else:
            return [self._adder_node_class(object=self),]

    @on_trait_change('children[]')
    def _trigger_children_ui_list(self, old, new):
        """ Trigger a children_ui_list change when scenes changed.
        """
        self.trait_property_changed('children_ui_list', old, new)

    def _visible_changed(self , value):
        # A hack to set the name when the tree view is not active.
        # `self.name` is set only when tno_get_label is called and this
        # is never called when the tree view is not shown leading to an
        # empty name.
        if len(self.name) == 0:
            self.tno_get_label(None)
        if value:
            #self._HideShowAction.name = "Hide"
            self.name = self.name.replace(' [Hidden]', '')
        else:
            #self._HideShowAction.name = "Show"
            n = self.name
            if ' [Hidden]' not in n:
                self.name = "%s [Hidden]" % n

    def _load_view_cached(self, name, view_element):
        """ Use a cached view for the object, for faster refresh.
        """
        if self._module_view is not None:
            view = self._module_view
        else:
            logger.debug("No view found for [%s] in [%s]. "
                         "Using the base class trait_view instead.",
                             self, self._view_filename)
            view = super(Base, self).trait_view(name, view_element)
        return view

    def _load_view_non_cached(self, name, view_element):
        """ Loads the view by execing a file. Useful when tweaking
            views.
        """
        result = {}
        view_filename = self._view_filename
        try:
            execfile(view_filename, {}, result)
            view = result['view']
        except IOError:
            logger.debug("No view found for [%s] in [%s]. "
                            "Using the base class trait_view instead.",
                            self, view_filename)
            view = super(Base, self).trait_view(name, view_element)
        return view

    def _hideshow(self):
        """Toggle visibility (target of the Hide/Show menu action)."""
        self.visible = not self.visible

    def _load_saved_state(self):
        """Load the saved state (if any) of this object.

        The saved state is cleared once it has been applied.
        """
        saved_state = self._saved_state
        if len(saved_state) > 0:
            state = cPickle.loads(saved_state)
            if hasattr(self, '__set_pure_state__'):
                self.__set_pure_state__(state)
            else:
                state_pickler.set_state(self, state)
            self._saved_state = ''

    def __view_filename_default(self):
        """ The name of the file that will host the view.
        """
        module = self.__module__.split('.')
        class_filename = module[-1] + '.py'
        # Drop the first two package components and the module itself.
        module_dir_name = module[2:-1]
        base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        view_filename = reduce(os.path.join,
                               [base_dir] + module_dir_name \
                               + UI_DIR_NAME + [class_filename])
        return view_filename


    def __module_view_default(self):
        """ Try to load a view for this object.

        Returns the ``view`` attribute of the module stored in
        ``self._view_filename``, or None if that file is missing or
        cannot be loaded.
        """
        view_filename = self._view_filename
        view_file = None
        try:
            view_file = open(view_filename)
            module = imp.load_module('view', view_file,
                                     view_filename, ('.py', 'U', 1))
            view = module.view
        except Exception:
            # Best effort: no usable view file simply means the default
            # trait view is used (see _load_view_cached).
            view = None
        finally:
            # imp.load_module leaves closing the file to the caller.
            if view_file is not None:
                view_file.close()
        return view


    def __menu_default(self):
        extras = []
        if self.menu_helper is not None:
            extras = self.menu_helper.actions + self._extra_menu_items()
        menu_actions = [Separator()] + extras + \
                       [Separator(), self._HideShowAction, Separator()] + \
                       deepcopy(standard_menu_actions)
        return Menu( *menu_actions)

    def __icon_path_default(self):
        return resource_path()

    def _extra_menu_items(self):
        """Override this to generate any new menu actions you want on
        the right click menu."""
        return []
Exemple #22
0
def _get_virtual_data(self, name):
    """ Property getter for a 'virtual' value trait.

    The trait named *name* carries an 'index' metadata entry that selects
    which item of ``self.data_name.items`` to read; the item's current
    choice value is returned.
    """
    position = self.trait(name).index
    item = self.data_name.items[position]
    return item.data_name_item_choice.choice_value


def _set_virtual_data(self, name, new_value):
    """ Property setter for a 'virtual' value trait.

    Replaces the choice of the item selected by the trait's 'index'
    metadata with a new TemplateChoice, then fires a change notification
    for *name*.  Nothing happens when the value is unchanged.
    """
    old_value = _get_virtual_data(self, name)
    if old_value != new_value:
        replacement = TemplateChoice(choice_value=new_value)
        position = self.trait(name).index
        self.data_name.items[position].data_name_item_choice = replacement

        self.trait_property_changed(name, old_value, new_value)


# Property trait factory for the 'valueN' traits: reads and writes the
# item selected by the trait's 'index' metadata.
VirtualValue = Property(fget=_get_virtual_data, fset=_set_virtual_data)


class VirtualDataName(HasPrivateTraits):
    """ Editable row wrapper around a TemplateDataName.

    Exposes the wrapped object's description plus the choice values of its
    first four items as 'value0' .. 'value3'.
    """

    #-- Wrapped object -------------------------------------------------------

    # The TemplateDataName this is a virtual copy of:
    data_name = Instance(TemplateDataName)

    # Description, delegated to (and writable through to) 'data_name':
    description = Delegate('data_name', modify=True)

    #-- 'Virtual' traits (one per item index) --------------------------------

    value0 = VirtualValue(index=0)
    value1 = VirtualValue(index=1)
    value2 = VirtualValue(index=2)
    value3 = VirtualValue(index=3)
Exemple #23
0
class UndoManager(HasTraits):
    """ Default implementation of the IUndoManager interface.

    Forwards undo/redo to the currently active command stack and exposes
    that stack's clean state and undo/redo names, with neutral values when
    no stack is active.
    """

    implements(IUndoManager)

    #### 'IUndoManager' interface #############################################

    # The currently active command stack (may be None).  Typically set when
    # some sort of editor becomes active.
    active_stack = Instance('enthought.undo.api.ICommandStack')

    # Clean state of the active command stack, intended to drive a
    # "document modified" indicator.  Maintained by the undo manager.
    active_stack_clean = Property(Bool)

    # Name of the command that can be redone; empty if there is none.
    # Maintained by the undo manager.
    redo_name = Property(Unicode)

    # Sequence number of the next command to be performed.  It is
    # incremented immediately before a command is invoked (by its 'do()'
    # method).
    sequence_nr = Int

    # Fired when the index of a command stack changes.  Its value is the
    # stack that changed, which is not necessarily the active one.
    stack_updated = Event

    # Name of the command that can be undone; empty if there is none.
    # Maintained by the undo manager.
    undo_name = Property(Unicode)

    ###########################################################################
    # 'IUndoManager' interface.
    ###########################################################################

    def redo(self):
        """ Redo the last undone command of the active command stack. """

        stack = self.active_stack
        if stack is None:
            return
        stack.redo()

    def undo(self):
        """ Undo the last command of the active command stack. """

        stack = self.active_stack
        if stack is None:
            return
        stack.undo()

    ###########################################################################
    # Private interface.
    ###########################################################################

    def _active_stack_changed(self, new):
        """ A different stack became active: report it as updated. """

        self.stack_updated = new

    def _get_active_stack_clean(self):
        """ Clean state of the active stack; True when there is none. """

        stack = self.active_stack
        return True if stack is None else stack.clean

    def _get_redo_name(self):
        """ Redo name of the active stack; empty when there is none. """

        stack = self.active_stack
        return "" if stack is None else stack.redo_name

    def _get_undo_name(self):
        """ Undo name of the active stack; empty when there is none. """

        stack = self.active_stack
        return "" if stack is None else stack.undo_name
Exemple #24
0
class TemplateDataNames(HasPrivateTraits):
    """ Binds a list of TemplateDataName objects to a data context.

    Also maintains the deduplicated, description-sorted list of
    VirtualDataName wrappers shown by the table editor.
    """

    #-- Public Traits ----------------------------------------------------------

    # The data context to which bindings are made:
    context = Instance(ITemplateDataContext)

    # The current set of data names to be bound to the context:
    data_names = List(TemplateDataName)

    # The list of unresolved, required bindings:
    unresolved_data_names = Property(depends_on='data_names.resolved')

    # The list of optional bindings:
    optional_data_names = Property(depends_on='data_names.optional')

    # The list of unresolved optional bindings:
    unresolved_optional_data_names = Property(
        depends_on='data_names.[resolved,optional]')

    #-- Private Traits ---------------------------------------------------------

    # List of 'virtual' data names for use by table editor:
    virtual_data_names = List

    # The list of table editor columns:
    table_columns = Property(depends_on='data_names')  # List( ObjectColumn )

    #-- Traits View Definitions ------------------------------------------------

    view = View(
        Item('virtual_data_names',
             show_label=False,
             style='custom',
             editor=table_editor))

    #-- Property Implementations -----------------------------------------------

    @cached_property
    def _get_unresolved_data_names(self):
        return [
            dn for dn in self.data_names
            if (not dn.resolved) and (not dn.optional)
        ]

    @cached_property
    def _get_optional_data_names(self):
        return [dn for dn in self.data_names if dn.optional]

    @cached_property
    def _get_unresolved_optional_data_names(self):
        return [
            dn for dn in self.data_names if (not dn.resolved) and dn.optional
        ]

    @cached_property
    def _get_table_columns(self):
        # One name column per item of the longest data name.  With no data
        # names at all, fall back to a single column instead of letting
        # max() fail on an empty sequence.
        n = max([len(dn.items) for dn in self.data_names] or [1])
        if n == 1:
            return std_columns + [
                BindingsColumn(name='value0', label='Name', width=0.43)
            ]
        width = 0.43 / n
        return (std_columns + [
            BindingsColumn(name='value%d' % i,
                           index=i,
                           label='Name %d' % (i + 1),
                           width=width) for i in range(n)
        ])

    #-- Trait Event Handlers ---------------------------------------------------

    def _context_changed(self, context):
        for data_name in self.data_names:
            data_name.context = context

    def _data_names_changed(self, old, new):
        """ Handles the list of 'data_names' being changed.
        """
        # Update the old and new context links:
        self._update_contexts(old, new)

        # Update the list of virtual names based on the new set:
        self._rebuild_virtual_data_names(new)

    def _data_names_items_changed(self, event):
        """ Handles items being added to/removed from 'data_names'.
        """
        # Update the old and new context links:
        self._update_contexts(event.old, event.new)

        # 'virtual_data_names' is deduplicated and sorted by description, so
        # its positions do not match 'event.index' in 'data_names'.
        # Rebuild it from the full list instead of splicing at that index.
        self._rebuild_virtual_data_names(self.data_names)

    #-- Private Methods --------------------------------------------------------

    def _rebuild_virtual_data_names(self, data_names):
        """ Rebuilds 'virtual_data_names' from *data_names*, dropping
            duplicates and sorting by description.
        """
        # Make sure that all of the names are unique:
        unique = set(data_names)
        dns = [VirtualDataName(data_name=dn) for dn in unique]
        # key= (rather than a cmp function) works on Python 2 and 3.
        dns.sort(key=lambda dn: dn.description)
        self.virtual_data_names = dns

    def _update_contexts(self, old, new):
        """ Updates the data context for an old and new set of data names.
        """
        for data_name in old:
            data_name.context = None

        context = self.context
        for data_name in new:
            data_name.context = context
Exemple #25
0
class TreeBranchNode(TreeNodeObject):
    """Represents a branch in the tree view.  The `children` trait
    produces an iterable that represents the children of the branch.
    """
    # The tvtk object being wrapped.
    object = Instance(HasTraits)
    # Children of the object.
    children = Property
    # Name to show on the view.
    name = Property(Str, depends_on='object.name?')
    # Tree generator to use.
    tree_generator = Instance(TreeGenerator)
    # Cache of children.
    children_cache = Dict

    # Work around problem with HasPrivateTraits.
    __ = Python

    def __del__(self):
        # Best-effort cleanup during garbage collection.  NOTE(review):
        # '_remove_listners' is not defined in this class; presumably it is
        # provided elsewhere -- confirm.  Catch Exception (not everything)
        # so KeyboardInterrupt/SystemExit are not swallowed.
        try:
            self._remove_listners()
        except Exception:
            pass

    #def __hash__(self):
    #    return hash(tvtk.to_vtk(self.object))

    def _get_children_from_cache(self):
        """Return the cached children, skipping None entries."""
        return [x for x in self.children_cache.values() if x is not None]

    def _create_children(self):
        """(Re)build the children cache from the tree generator."""
        kids = self.tree_generator.get_children(self.object)
        self.children_cache = kids

    def _notify_children(self, obj=None, name=None, old=None, new=None):
        """Rebuild the children and fire a 'children' change event."""
        old_val = self._get_children_from_cache()
        self._remove_listners()
        self._create_children()
        new_val = self._get_children_from_cache()
        self.trait_property_changed('children', old_val, new_val)

    def _get_children(self):
        # Lazily populate the cache on first access.
        if not self.children_cache:
            self._create_children()
        kids = self._get_children_from_cache()
        tg = self.tree_generator
        return CompositeIterable(kids, tree_generator=tg)

    def _get_name(self):
        obj = self.object
        if isinstance(obj, AVL):
            return 'pyAVL'
        if isinstance(obj, (Case, RunCase, Surface, Body)):
            return obj.name
        # Any other object is labelled with its class name.
        return obj.__class__.__name__

    ######################################################################
    # `TreeNodeObject` interface
    ######################################################################
    def tno_get_icon(self, node, is_expanded):
        """ Returns the icon for a specified object.

        Falls back to the base-class icon when get_icon finds none.
        """
        icon = get_icon(self.name)
        if icon:
            return icon
        else:
            return super(TreeBranchNode, self).tno_get_icon(node, is_expanded)
Exemple #26
0
class ActionItem(ActionManagerItem):
    """ An action manager item that represents an actual action. """

    #### 'ActionManagerItem' interface ########################################

    # The item's unique identifier ('unique' in this case means unique within
    # its group).
    id = Property(Str)

    #### 'ActionItem' interface ###############################################

    # The action!
    action = Instance(Action)

    # The toolkit specific control created for this item.
    control = Any

    # The toolkit specific Id of the control created for this item.
    #
    # We have to keep the Id as well as the control because wx tool bar tools
    # are created as 'wxObjectPtr's which do not have Ids, and the Id is
    # required to manipulate the state of a tool via the tool bar 8^(
    # FIXME v3: Why is this part of the public interface?
    control_id = Any

    #### Private interface ####################################################

    # All of the internal instances that wrap this item.
    _wrappers = List(Any)

    ###########################################################################
    # 'ActionManagerItem' interface.
    ###########################################################################

    #### Trait properties #####################################################

    def _get_id(self):
        """ Returns the item's Id (the Id of its action). """

        return self.action.id

    #### Trait change handlers ################################################

    def _enabled_changed(self, trait_name, old, new):
        """ Static trait change handler: mirror 'enabled' on the action. """

        self.action.enabled = new

    def _visible_changed(self, trait_name, old, new):
        """ Static trait change handler: mirror 'visible' on the action. """

        # Propagate the new value (previously this always forced True, so
        # hiding an item never hid its action).
        self.action.visible = new

    ###########################################################################
    # 'ActionItem' interface.
    ###########################################################################

    def add_to_menu(self, parent, menu, controller):
        """ Adds the item to a menu.

        Nothing is added if a controller is given and refuses the action.
        """

        if (controller is None) or controller.can_add_to_menu(self.action):
            wrapper = _MenuItem(parent, menu, self, controller)

            # fixme: Martin, who uses this information?
            if controller is None:
                self.control = wrapper.control
                self.control_id = wrapper.control_id

            self._wrappers.append(wrapper)

    def add_to_toolbar(self, parent, tool_bar, image_cache, controller,
                       show_labels=True):
        """ Adds the item to a tool bar.

        Nothing is added if a controller is given and refuses the action.
        """

        if (controller is None) or controller.can_add_to_toolbar(self.action):
            wrapper = _Tool(
                parent, tool_bar, image_cache, self, controller, show_labels
            )

            # fixme: Martin, who uses this information?
            if controller is None:
                self.control = wrapper.control
                self.control_id = wrapper.control_id

            self._wrappers.append(wrapper)

    def add_to_palette(self, tool_palette, image_cache, show_labels=True):
        """ Adds the item to a tool palette. """

        wrapper = _PaletteTool(tool_palette, image_cache, self, show_labels)

        self._wrappers.append(wrapper)

    def destroy(self):
        """ Called when the action is no longer required.

        By default this method calls 'destroy' on the action itself.
        """

        self.action.destroy()
Exemple #27
0
class Position(HasTraits):
    """ Simple object to act as a data structure for a position

        While all attributes (traits) are optional, classes that contain or
        collect instances of the Position class will require the following:
        symbol, trans_date, qty, price, total_amt

    """

    # "SELLTOCLOS" is a historical misspelling kept so existing data still
    # validates; "SELLTOCLOSE" is the correctly spelled value.
    side = Enum("BUYTOOPEN", ["SELLTOCLOS", "SELLTOCLOSE", "BUYTOOPEN",
                              "SELLTOOPEN", "BUYTOCLOSE"])
    symbol = Str
    id = Int
    description = Str
    trans_date = Float
    qty = Float
    price = Float
    multiplier = Float(1.0)
    fee = Float
    exchange_rate = Float(1.0)
    currency = Str("USD")
    total_amt = Float
    filled = Str
    exchange = Str

    # The following traits are for viewing and editing the datetime value
    #     of trans_date (which is a float of seconds since the Epoch).
    # Raw strings avoid invalid-escape warnings for '\d'.
    date_display = Property(Regex(value='11/17/1969',
                                  regex=r'\d\d[/]\d\d[/]\d\d\d\d'),
                                  depends_on='trans_date')
    time_display = Property(Regex(value='12:01:01',
                                  regex=r'\d\d[:]\d\d[:]\d\d'),
                                  depends_on='trans_date')

    # specify default view layout
    traits_view = View(Item('symbol', label="Symb"),
                       Item('date_display'),
                       Item('time_display'),
                       Item('qty'),
                       buttons=['OK', 'Cancel'], resizable=True)

    ###################################
    # Property methods
    def _get_date_display(self):
        """Date part of trans_date (Eastern time) as MM/DD/YYYY."""
        return dt_from_timestamp(self.trans_date, tz=Eastern).strftime("%m/%d/%Y")

    def _set_date_display(self, val):
        """Set the date part of trans_date, keeping the current time."""
        tm = self._get_time_display()
        trans_date = datetime.strptime(val+tm, "%m/%d/%Y%H:%M:%S" )
        # NOTE(review): if Eastern is a pytz zone, replace(tzinfo=...) picks
        # its LMT offset; Eastern.localize() would be required -- confirm.
        trans_date = trans_date.replace(tzinfo=Eastern)
        self.trans_date = dt_to_timestamp(trans_date)

    def _get_time_display(self):
        """Time part of trans_date (Eastern time) as HH:MM:SS."""
        t = dt_from_timestamp(self.trans_date, tz=Eastern).strftime("%H:%M:%S")
        return t

    def _set_time_display(self, val):
        """Set the time part of trans_date, keeping the current date."""
        trans_time = datetime.strptime(self._get_date_display()+val, "%m/%d/%Y%H:%M:%S")
        # NOTE(review): same tzinfo caveat as in _set_date_display.
        trans_time = trans_time.replace(tzinfo=Eastern)
        self.trans_date = dt_to_timestamp(trans_time)

    ###################################
    # Override default class methods

    # cleaner, more reasonable representation of the object
    def __repr__(self):
        return "<Position %s %s>" % (self.symbol, self.qty)

    # support reasonable sorting based on trans_date (Python 2)
    def __cmp__(self, other):
        if self.trans_date < other.trans_date:
            return -1
        elif self.trans_date > other.trans_date:
            return 1
        else: return 0

    # Rich ordering by trans_date so sorting also works on Python 3, where
    # __cmp__ is ignored.  __eq__ is deliberately not overridden, so
    # HasTraits identity equality and hashing are unaffected.
    def __lt__(self, other):
        return self.trans_date < other.trans_date

    def __le__(self, other):
        return self.trans_date <= other.trans_date

    def __gt__(self, other):
        return self.trans_date > other.trans_date

    def __ge__(self, other):
        return self.trans_date >= other.trans_date

#### EOF ####################################################################
Exemple #28
0
class Observation(HasTraits):
    """Observation of horizontal angle, zenith angle, and slope distance
       collected at a BaseSetup.

    The observation is reduced to coordinates: ``x``/``y`` combine the
    horizontal distance with the horizontal angle plus the base's
    ``horizontal_angle_offset``; ``z`` adds the vertical distance to the
    base's ``z`` and ``z_offset`` and subtracts this observation's
    ``z_offset``.
    """
    id = String
    base = Instance(BaseSetup, kw={'x': 0,
                                   'y': 0,
                                   'z': 0,
                                   'z_offset': 0})
    zenith_angle = Instance(AngleDMS, kw={'degrees': 90,
                                          'minutes': 0,
                                          'seconds': 0})
    horizontal_angle = Instance(AngleDMS, kw={'degrees': 0,
                                              'minutes': 0,
                                              'seconds': 0})
    z_offset = Float
    slope_distance = Float

    # Each Property lists exactly the traits its getter reads (directly or
    # through another cached property), so a cached value is invalidated
    # only when one of its real inputs changes.  Implicit string
    # concatenation is used instead of backslash continuations, which
    # embedded long runs of indentation into the dependency strings.
    horizontal_distance = Property(
        depends_on='zenith_angle.radians, slope_distance')
    vertical_distance = Property(
        depends_on='zenith_angle.radians, slope_distance')
    x = Property(depends_on='base.x, '
                            'base.horizontal_angle_offset.radians, '
                            'horizontal_angle.radians, '
                            'zenith_angle.radians, '
                            'slope_distance')
    y = Property(depends_on='base.y, '
                            'base.horizontal_angle_offset.radians, '
                            'horizontal_angle.radians, '
                            'zenith_angle.radians, '
                            'slope_distance')
    z = Property(depends_on='base.z, '
                            'base.z_offset, '
                            'zenith_angle.radians, '
                            'slope_distance, '
                            'z_offset')

    @cached_property
    def _get_horizontal_distance(self):
        # Horizontal projection of the slope distance.
        return self.slope_distance * sin(self.zenith_angle.radians)

    @cached_property
    def _get_vertical_distance(self):
        # Vertical projection of the slope distance.
        return self.slope_distance * cos(self.zenith_angle.radians)

    @cached_property
    def _get_x(self):
        angle = (self.horizontal_angle.radians
                 + self.base.horizontal_angle_offset.radians)
        return sin(angle) * self.horizontal_distance + self.base.x

    @cached_property
    def _get_y(self):
        angle = (self.horizontal_angle.radians
                 + self.base.horizontal_angle_offset.radians)
        return cos(angle) * self.horizontal_distance + self.base.y

    @cached_property
    def _get_z(self):
        # Reuses vertical_distance (same formula as before), mirroring how
        # x and y reuse horizontal_distance.
        return self.vertical_distance \
               + self.base.z \
               + self.base.z_offset - self.z_offset

    view = View(Item('base', style='custom', show_label=False),
                Group(Item('horizontal_angle', style='custom'),
                      Item('zenith_angle', style='custom'),
                      Item('slope_distance'),
                      Item('z_offset'),
                      label='Observation',
                      show_border=True),
                HGroup(Item('x',
                            format_str='%.3f',
                            springy=True),
                       Item('y',
                            format_str='%.3f',
                            springy=True),
                       Item('z',
                            format_str='%.3f',
                            springy=True),
                       label='Reduced coordinates',
                       show_border=True),
                buttons=LiveButtons)
Exemple #29
0
class OrientationAxes(Module):
    """Displays a tvtk ``AxesActor`` through a tvtk
    ``OrientationMarkerWidget`` registered in this module's widgets."""

    # The version of this class.  Used for persistence.
    __version__ = 0

    # The tvtk orientation marker widget.
    marker = Instance(tvtk.OrientationMarkerWidget, allow_none=False)

    # The tvtk axes that will be shown.
    axes = Instance(tvtk.AxesActor, allow_none=False, record=True)

    # The property of the axes (color etc.).
    text_property = Property(record=True)

    input_info = PipelineInfo(datasets=['any'],
                              attribute_types=['any'],
                              attributes=['any'])

    ########################################
    # Private traits.

    # Caption text property shared by the x, y and z axis labels.
    _text_property = Instance(tvtk.TextProperty)

    ########################################
    # The view of this object.

    _marker_group = Group(Item(name='enabled'),
                          Item(name='interactive'),
                          show_border=True,
                          label='Widget')
    _axes_group = Group(Item(name='axis_labels'),
                        Item(name='visibility'),
                        Item(name='x_axis_label_text'),
                        Item(name='y_axis_label_text'),
                        Item(name='z_axis_label_text'),
                        Item(name='cone_radius'),
                        Item(name='cone_resolution'),
                        Item(name='cylinder_radius'),
                        Item(name='cylinder_resolution'),
                        Item(name='normalized_label_position'),
                        Item(name='normalized_shaft_length'),
                        Item(name='normalized_tip_length'),
                        Item(name='total_length'),
                        show_border=True,
                        label='Axes')

    view = View(Group(Item(name='marker', style='custom',
                           editor=InstanceEditor(view=View(_marker_group))),
                      Item(name='axes', style='custom',
                           editor=InstanceEditor(view=View(_axes_group))),
                      label='Widget/Axes',
                      show_labels=False),
                Group(Item(name='_text_property', style='custom',
                           resizable=True),
                      label='Text Property',
                      show_labels=False),
                )

    ######################################################################
    # `object` interface
    ######################################################################
    def __set_pure_state__(self, state):
        # Apply the saved state onto the existing axes, marker and text
        # property objects (in that order).
        for name in ('axes', 'marker', '_text_property'):
            state_pickler.set_state(getattr(self, name), state[name])

    ######################################################################
    # `Module` interface
    ######################################################################
    def setup_pipeline(self):
        """Create the tvtk objects used by this module.

        Called from `__init__`, before the upstream tvtk data pipeline is
        set up, so only objects independent of upstream data are built
        here.  The axes must be assigned before the text property is
        configured, because `_axes_changed` is what sets `_text_property`.
        """
        default_axes = tvtk.AxesActor(normalized_tip_length=(0.4, 0.4, 0.4),
                                      normalized_shaft_length=(0.6, 0.6, 0.6),
                                      shaft_type='cylinder')
        self.axes = default_axes
        self.text_property.set(color=(1, 1, 1), shadow=False, italic=False)

        self.marker = tvtk.OrientationMarkerWidget(key_press_activation=False)

    def update_pipeline(self):
        """Signal that the pipeline changed.

        Invoked automatically when an input fires `pipeline_changed`; this
        module simply re-fires `pipeline_changed` itself.
        """
        self.pipeline_changed = True

    def update_data(self):
        """Signal that the data changed.

        Invoked automatically when an input fires `data_changed`; this
        module only re-fires `data_changed`, the component does the rest.
        """
        self.data_changed = True

    ######################################################################
    # Non-public interface
    ######################################################################
    def _marker_changed(self, old, new):
        # Detach the previous widget from rendering and the widget list.
        if old is not None:
            old.on_trait_change(self.render, remove=True)
            self.widgets.remove(old)
        # Point the new widget at the current axes, if any exist yet.
        if self.axes is not None:
            new.orientation_marker = self.axes
        new.on_trait_change(self.render)

        self.widgets.append(new)
        self.render()

    def _axes_changed(self, old, new):
        # Stop re-rendering on changes of the previous axes and their
        # shared text property.
        if old is not None:
            old.on_trait_change(self.render, remove=True)
            self._text_property.on_trait_change(self.render, remove=True)
        if self.marker is not None:
            self.marker.orientation_marker = new

        # Make all three axis captions share the x-axis caption's text
        # property, and expose it as `_text_property`.
        shared = new.x_axis_caption_actor2d.caption_text_property
        for caption in (new.y_axis_caption_actor2d,
                        new.z_axis_caption_actor2d):
            caption.caption_text_property = shared
        self._text_property = shared

        # XXX: The line of code below is a stop-gap solution. Without it,
        # Some observers in the AxesActor trigger a modification of the
        # font_size each time the mouse is moved over the OrientationAxes
        # (this can be seen when running the record mode, for instance),
        # and thus a render, which is very slow. On the other hand, font
        # size does not work for the AxesActor, with or without the
        # line of code below. So we probably haven't found the true
        # cause of the problem.
        shared.teardown_observers()

        new.on_trait_change(self.render)
        shared.on_trait_change(self.render)

        self.render()

    def _get_text_property(self):
        # Read-only view of the shared caption text property.
        return self._text_property

    def _foreground_changed_for_scene(self, old, new):
        # Use the scene's foreground colour as the label text colour.
        self.text_property.color = new
        self.render()

    def _scene_changed(self, old, new):
        super(OrientationAxes, self)._scene_changed(old, new)
        # Adopt the new scene's foreground colour and re-apply visibility.
        self._foreground_changed_for_scene(None, new.foreground)
        self._visible_changed(self.visible)

    def _visible_changed(self, value):
        # Enabling an OrientationAxes without an interactor will lead to a
        # segfault, so only forward the change once a scene exists and the
        # marker has an interactor.
        if self.scene is None or not self.marker.interactor:
            return
        super(OrientationAxes, self)._visible_changed(value)
Exemple #30
0
class Threshold(Filter):

    # The version of this class.  Used for persistence.
    __version__ = 0

    # The threshold filter used.
    threshold_filter = Property(Instance(tvtk.Object, allow_none=False),
                                record=True)

    # The filter type to use: specifies whether cells or points are
    # filtered via the threshold.
    filter_type = Enum('cells',
                       'points',
                       desc='if thresholding is done on cells or points')

    # Lower threshold (this is a dynamic trait that is changed when
    # input data changes).
    lower_threshold = Range(value=-1.0e20,
                            low='_data_min',
                            high='_data_max',
                            enter_set=True,
                            auto_set=False,
                            desc='the lower threshold of the filter')

    # Upper threshold (this is a dynamic trait that is changed when
    # input data changes).
    upper_threshold = Range(value=1.0e20,
                            low='_data_min',
                            high='_data_max',
                            enter_set=True,
                            auto_set=False,
                            desc='the upper threshold of the filter')

    # Automatically reset the lower threshold when the upstream data
    # changes.
    auto_reset_lower = Bool(True,
                            desc='if the lower threshold is '
                            'automatically reset when upstream '
                            'data changes')

    # Automatically reset the upper threshold when the upstream data
    # changes.
    auto_reset_upper = Bool(True,
                            desc='if the upper threshold is '
                            'automatically reset when upstream '
                            'data changes')

    input_info = PipelineInfo(datasets=['any'],
                              attribute_types=['any'],
                              attributes=['any'])

    output_info = PipelineInfo(datasets=['poly_data', 'unstructured_grid'],
                               attribute_types=['any'],
                               attributes=['any'])

    # Our view.
    view = View(Group(
        Group(Item(name='filter_type'), Item(name='lower_threshold'),
              Item(name='auto_reset_lower'), Item(name='upper_threshold'),
              Item(name='auto_reset_upper')),
        Item(name='_'),
        Group(
            Item(name='threshold_filter',
                 show_label=False,
                 visible_when='object.filter_type == "cells"',
                 style='custom',
                 resizable=True)),
    ),
                resizable=True)

    ########################################
    # Private traits.

    # These traits are used to set the limits for the thresholding.
    # They store the minimum and maximum values of the input data.
    _data_min = Float(-1e20)
    _data_max = Float(1e20)

    # The threshold filter for cell based filtering
    _threshold = Instance(tvtk.Threshold, args=(), allow_none=False)

    # The threshold filter for points based filtering.
    _threshold_points = Instance(tvtk.ThresholdPoints,
                                 args=(),
                                 allow_none=False)

    # Internal flag: True until the limits and thresholds have been
    # initialized from the first non-empty input data range.
    _first = Bool(True)

    ######################################################################
    # `object` interface.
    ######################################################################
    def __get_pure_state__(self):
        """Return the persistable state, minus the dynamically derived
        range bookkeeping traits."""
        d = super(Threshold, self).__get_pure_state__()
        # These traits are dynamically created.
        for name in ('_first', '_data_min', '_data_max'):
            d.pop(name, None)

        return d

    ######################################################################
    # `Filter` interface.
    ######################################################################
    def setup_pipeline(self):
        """Re-run the cell threshold filter and fire `data_changed`
        whenever one of its scalar-selection settings is edited."""
        attrs = [
            'all_scalars', 'attribute_mode', 'component_mode',
            'selected_component'
        ]
        self._threshold.on_trait_change(self._threshold_filter_edited, attrs)

    def update_pipeline(self):
        """Override this method so that it *updates* the tvtk pipeline
        when data upstream is known to have changed.

        This method is invoked (automatically) when the input fires a
        `pipeline_changed` event.
        """
        if len(self.inputs) == 0:
            return

        # By default we set the input to the first output of the first
        # input.
        fil = self.threshold_filter
        fil.input = self.inputs[0].outputs[0]

        self._update_ranges()
        self._set_outputs([self.threshold_filter.output])

    def update_data(self):
        """Override this method to do what is necessary when upstream
        data changes.

        This method is invoked (automatically) when any of the inputs
        sends a `data_changed` event.
        """
        if len(self.inputs) == 0:
            return

        self._update_ranges()

        # Propagate the data_changed event.
        self.data_changed = True

    ######################################################################
    # Non-public interface
    ######################################################################
    def _lower_threshold_changed(self, new_value):
        fil = self.threshold_filter
        fil.threshold_between(new_value, self.upper_threshold)
        fil.update()
        self.data_changed = True

    def _upper_threshold_changed(self, new_value):
        fil = self.threshold_filter
        fil.threshold_between(self.lower_threshold, new_value)
        fil.update()
        self.data_changed = True

    def _update_ranges(self):
        """Updates the ranges of the input.

        On the first call with data, both limits and both thresholds are
        set from the data range; afterwards only the sides whose
        `auto_reset_*` flag is set are updated.
        """
        dr = self._get_data_range()
        if len(dr) == 0:
            return

        first = self._first
        reset_lower = first or self.auto_reset_lower
        reset_upper = first or self.auto_reset_upper

        # Move the Range limits before assigning the thresholds, so the new
        # threshold values validate against the new dynamic bounds (e.g.
        # when the whole data range moves above the previous maximum).
        if reset_lower:
            self._data_min = dr[0]
        if reset_upper:
            self._data_max = dr[1]

        if reset_lower:
            # If the upper threshold is about to change, its handler applies
            # both thresholds to the filter in a single update, so the lower
            # change can be silent.  If the upper value is unchanged, no
            # handler would fire and the filter would keep a stale lower
            # bound -- in that case let the lower change notify.
            upper_will_fire = reset_upper and self.upper_threshold != dr[1]
            self.set(lower_threshold=dr[0],
                     trait_change_notify=not upper_will_fire)
        if reset_upper:
            self.upper_threshold = dr[1]

        self._first = False

    def _get_data_range(self):
        """Returns the range of the input scalar data as ``[min, max]``,
        or an empty list when the input has neither point nor cell
        scalars.  Point scalars take precedence over cell scalars."""
        data = self.inputs[0].outputs[0]

        # FIXME: need to be able to handle cell and point data
        # together.
        scalars = data.point_data.scalars
        if scalars is None:
            scalars = data.cell_data.scalars
        if scalars is None:
            return []
        return self._nan_safe_range(scalars)

    @staticmethod
    def _nan_safe_range(scalars):
        """Return ``[min, max]`` of a tvtk data array as a mutable list,
        replacing any bound reported as NaN with the NaN-ignoring
        min/max of the array's values."""
        lo, hi = scalars.range
        if np.isnan(lo) or np.isnan(hi):
            values = scalars.to_array()
            if np.isnan(lo):
                lo = float(np.nanmin(values))
            if np.isnan(hi):
                hi = float(np.nanmax(values))
        return [lo, hi]

    def _auto_reset_lower_changed(self, value):
        if len(self.inputs) == 0 or not value:
            return
        dr = self._get_data_range()
        # No scalars upstream: nothing to reset to.
        if len(dr) == 0:
            return
        self._data_min = dr[0]
        self.lower_threshold = dr[0]

    def _auto_reset_upper_changed(self, value):
        if len(self.inputs) == 0 or not value:
            return
        dr = self._get_data_range()
        # No scalars upstream: nothing to reset to.
        if len(dr) == 0:
            return
        self._data_max = dr[1]
        self.upper_threshold = dr[1]

    def _get_threshold_filter(self):
        if self.filter_type == 'cells':
            return self._threshold
        else:
            return self._threshold_points

    def _filter_type_changed(self, value):
        if value == 'cells':
            old = self._threshold_points
            new = self._threshold
        else:
            old = self._threshold
            new = self._threshold_points
        self.trait_property_changed('threshold_filter', old, new)

    def _threshold_filter_changed(self, old, new):
        if len(self.inputs) == 0:
            return
        fil = new
        fil.input = self.inputs[0].outputs[0]
        fil.threshold_between(self.lower_threshold, self.upper_threshold)
        fil.update()
        self._set_outputs([fil.output])

    def _threshold_filter_edited(self):
        self.threshold_filter.update()
        self.data_changed = True