Esempio n. 1
0
    def __init__(self, default_value=None, values=(), iotype=None,
                 aliases=(), desc=None, **metadata):
        """Set up the enumeration from its allowed values and default.

        Parameters
        ----------
        default_value : object, optional
            Initial value. When `values` is empty, this argument supplies the
            allowed values instead: either a single value, or a tuple/list
            whose first item becomes the default.
        values : tuple or list, optional
            Allowed values. A single non-sequence value is wrapped in a tuple.
        iotype : object, optional
            Stored under the 'iotype' metadata key when not None.
        aliases : tuple or list, optional
            Stored under the 'aliases' metadata key. Must contain exactly one
            entry per allowed value. A single non-sequence alias is wrapped
            in a tuple.
        desc : object, optional
            Stored under the 'desc' metadata key when not None.
        **metadata
            Additional metadata, forwarded to both the TraitEnum validator
            and the parent constructor.

        Raises
        ------
        ValueError
            If no allowed value can be determined, if the number of aliases
            differs from the number of values, or if the default value is
            not one of the allowed values.
        """
        no_values_msg = "Enum must contain at least one value."

        # Allow some variant constructors (no default, no index)
        if not values:
            if default_value is None:
                raise ValueError(no_values_msg)
            values = default_value
            if isinstance(values, (tuple, list)):
                # An empty sequence used to fail with a bare IndexError on
                # values[0]; report it the same way as "no values at all".
                if not values:
                    raise ValueError(no_values_msg)
                default_value = values[0]
        elif default_value is None:
            default_value = values[0]

        # We need tuples or a list for the index
        if not isinstance(values, (tuple, list)):
            values = (values,)

        if aliases:
            if not isinstance(aliases, (tuple, list)):
                aliases = (aliases,)

            if len(aliases) != len(values):
                raise ValueError("Length of aliases does not match " + \
                                 "length of values.")

        if default_value not in values:
            raise ValueError("Default value not in values.")

        # The validator is built before iotype/desc/values/aliases are added
        # to the metadata, so it only sees the caller-supplied metadata.
        self._validator = TraitEnum(default_value, values, **metadata)

        # Put iotype in the metadata dictionary
        if iotype is not None:
            metadata['iotype'] = iotype

        # Put desc in the metadata dictionary
        if desc is not None:
            metadata['desc'] = desc

        # Put values in the metadata dictionary (values is guaranteed to be a
        # non-empty tuple or list at this point).
        metadata['values'] = values

        # We also need to store the values in a dict, to get around
        # a weak typechecking (i.e., enum of [1,2,3] can be 1.0)
        self.valuedict = dict((val, val) for val in values)

        # Put aliases in the metadata dictionary
        if aliases:
            metadata['aliases'] = aliases

        super(Enum, self).__init__(default_value=default_value,
                                   **metadata)
Esempio n. 2
0
class MATS2D5PlasticBond(MATS2DEval):
    '''
    Two-layer 2D material model with a plastic bond component.

    The elastic matrix is 8x8 and block diagonal: components 0-2 use the
    ``*_m`` parameters, components 3-5 use the ``*_f`` parameters and
    components 6 and 7 use the modulus ``G``. Component 6 is additionally
    corrected by a one-dimensional return mapping using ``sigma_y``,
    ``K_bar`` and ``H_bar`` (see get_corr_pred).

    NOTE(review): the suffixes ``_m`` / ``_f`` presumably stand for matrix
    and fiber/reinforcement - confirm against the calling element.
    '''

    implements(IMATSEval)

    #---------------------------------------------------------------------------
    # Parameters of the numerical algorithm (integration)
    #---------------------------------------------------------------------------

    stress_state = Enum("plane_stress", "plane_strain")

    #---------------------------------------------------------------------------
    # Material parameters
    #---------------------------------------------------------------------------

    E_m = Float(
        1.,  #34e+3,
        label="E_m",
        desc="Young's Modulus",
        auto_set=False)
    nu_m = Float(0.2, label='nu_m', desc="Poison's ratio", auto_set=False)

    E_f = Float(
        1.,  #34e+3,
        label="E_f",
        desc="Young's Modulus",
        auto_set=False)
    nu_f = Float(0.2, label='nu_f', desc="Poison's ratio", auto_set=False)

    G = Float(
        1.,  #34e+3,
        label="G",
        desc="Shear Modulus",
        auto_set=False)

    sigma_y = Float(
        .5,  #34e+3,
        label="s_y",
        desc="Yield stress",
        auto_set=False)

    K_bar = Float(
        0.,  #34e+3,
        label="K",
        desc="isotropic hardening",
        auto_set=False)

    H_bar = Float(
        0.,  #34e+3,
        label="H",
        desc="kinematic hardening",
        auto_set=False)

    # Elastic matrix. Every parameter read by the two _get_D_* builders must
    # be listed here, otherwise the cached value is not refreshed when that
    # parameter changes (nu_m used to be missing, nu_f was listed twice).
    D_el = Property(Array(float),
                    depends_on='E_f, nu_f, E_m, nu_m, G, stress_state')

    @cached_property
    def _get_D_el(self):
        '''Build the elastic matrix matching the current stress_state.'''
        if self.stress_state == "plane_stress":
            return self._get_D_plane_stress()
        else:
            return self._get_D_plane_strain()

    # This event can be used by the clients to trigger an action upon
    # the completed reconfiguration of the material model
    #
    changed = Event

    #---------------------------------------------------------------------------------------------
    # View specification
    #---------------------------------------------------------------------------------------------

    view_traits = View(VSplit(
        Group(Item('E_m'), Item('nu_m'), Item('E_f'), Item('nu_f'), Item('G')),
        Group(
            Item('stress_state', style='custom'),
            Spring(resizable=True),
            label='Configuration parameters',
            show_border=True,
        ),
    ),
                       resizable=True)

    #-----------------------------------------------------------------------------------------------
    # Private initialization methods
    #-----------------------------------------------------------------------------------------------

    #-----------------------------------------------------------------------------------------------
    # Setup for computation within a supplied spatial context
    #-----------------------------------------------------------------------------------------------

    def new_cntl_var(self):
        '''Return a fresh zero vector of length 3.'''
        return zeros(3, float_)

    def new_resp_var(self):
        '''Return a fresh zero vector of length 3.'''
        return zeros(3, float_)

    def get_state_array_size(self):
        '''Number of state variables per point: eps_p, q and alpha.'''
        return 3

    #-----------------------------------------------------------------------------------------------
    # Evaluation - get the corrector and predictor
    #-----------------------------------------------------------------------------------------------

    def get_corr_pred(self, sctx, eps_app_eng, d_eps, tn, tn1):
        '''
        Corrector predictor computation.

        @param sctx spatial context; ``update_state_on`` is read and
               ``mats_state_array`` is read (and rewritten when
               ``update_state_on`` is set)
        @param eps_app_eng input variable - engineering strain
        @param d_eps strain increment; only component 6 is read, and only
               when ``sctx.update_state_on`` is set
        @param tn not used by this model
        @param tn1 not used by this model
        @return tuple (sigma, D) - the stress vector and the matrix whose
                entry [6, 6] holds the tangent of the plastic component
        '''
        # Work on a private copy: writing the tangent into self.D_el would
        # modify the cached elastic matrix shared by all later evaluations.
        D_mtx = self.D_el.copy()
        sigma = dot(D_mtx, eps_app_eng)

        # Only component 6 takes part in the plastic correction.
        eps_n1 = float(eps_app_eng[6])  #hack for this particular case
        G = self.G

        if sctx.update_state_on:
            # Recover the strain of the last equilibrated step and store the
            # state variables belonging to it.
            eps_n = eps_n1 - float(d_eps[6])
            sctx.mats_state_array[:] = self._get_state_variables(sctx, eps_n)

        eps_p_n, q_n, alpha_n = sctx.mats_state_array
        sigma_trial = G * (eps_n1 - eps_p_n)
        xi_trial = sigma_trial - q_n
        f_trial = abs(xi_trial) - (self.sigma_y + self.K_bar * alpha_n)

        if f_trial <= 1e-8:
            # Trial state is admissible (within tolerance): elastic step.
            sig_n1 = sigma_trial
            D_n1 = G
        else:
            # Return mapping with linear isotropic and kinematic hardening.
            d_gamma = f_trial / (G + self.K_bar + self.H_bar)
            sig_n1 = sigma_trial - d_gamma * G * sign(xi_trial)
            D_n1 = (G * (self.K_bar + self.H_bar)) / \
                            (G + self.K_bar + self.H_bar)

        sigma[6] = sig_n1
        D_mtx[6, 6] = D_n1
        return sigma, D_mtx

    #---------------------------------------------------------------------------------------------
    # Subsidiary methods realizing configurable features
    #---------------------------------------------------------------------------------------------

    def _get_state_variables(self, sctx, eps_n):
        '''
        Return the state variables (eps_p, q, alpha) for the strain eps_n.

        The values stored in ``sctx.mats_state_array`` are used as the
        starting point; they are advanced by one return-mapping step when
        the trial state for eps_n violates the yield condition.
        '''
        eps_p_n, q_n, alpha_n = sctx.mats_state_array

        # Get the characteristics of the trial step
        #
        sig_trial = self.G * (eps_n - eps_p_n)
        xi_trial = sig_trial - q_n
        f_trial = abs(xi_trial) - (self.sigma_y + self.K_bar * alpha_n)

        if f_trial > 1e-8:

            #
            # The last equilibrated step was inelastic. Here the
            # corresponding state variables must be calculated once
            # again. This might be expensive for 2D and 3D models. Then,
            # some kind of caching should be considered for the state
            # variables determined during iteration. In particular, the
            # computation of d_gamma should be outsourced into a separate
            # method that can in general perform an iterative computation.
            #
            d_gamma = f_trial / (self.G + self.K_bar + self.H_bar)
            eps_p_n += d_gamma * sign(xi_trial)
            q_n += d_gamma * self.H_bar * sign(xi_trial)
            alpha_n += d_gamma

        newarr = array([eps_p_n, q_n, alpha_n], dtype='float_')

        return newarr

    def _get_D_plane_stress(self):
        '''Assemble the 8x8 elastic matrix for the plane-stress case.'''
        E_m = self.E_m
        nu_m = self.nu_m
        E_f = self.E_f
        nu_f = self.nu_f
        G = self.G
        D_stress = zeros([8, 8])
        # Block 0-2: built from E_m, nu_m
        D_stress[0, 0] = E_m / (1.0 - nu_m * nu_m)
        D_stress[0, 1] = E_m / (1.0 - nu_m * nu_m) * nu_m
        D_stress[1, 0] = E_m / (1.0 - nu_m * nu_m) * nu_m
        D_stress[1, 1] = E_m / (1.0 - nu_m * nu_m)
        D_stress[2, 2] = E_m / (1.0 - nu_m * nu_m) * (1.0 / 2.0 - nu_m / 2.0)

        # Block 3-5: built from E_f, nu_f
        D_stress[3, 3] = E_f / (1.0 - nu_f * nu_f)
        D_stress[3, 4] = E_f / (1.0 - nu_f * nu_f) * nu_f
        D_stress[4, 3] = E_f / (1.0 - nu_f * nu_f) * nu_f
        D_stress[4, 4] = E_f / (1.0 - nu_f * nu_f)
        D_stress[5, 5] = E_f / (1.0 - nu_f * nu_f) * (1.0 / 2.0 - nu_f / 2.0)

        # Components 6 and 7: stiffness G
        D_stress[6, 6] = G
        D_stress[7, 7] = G
        return D_stress

    def _get_D_plane_strain(self):
        '''Assemble the 8x8 elastic matrix for the plane-strain case.'''
        #TODO: adapt to use arbitrary 2d model following the 1d5 bond
        E_m = self.E_m
        nu_m = self.nu_m
        E_f = self.E_f
        nu_f = self.nu_f
        G = self.G
        D_strain = zeros([8, 8])
        # Block 0-2: built from E_m, nu_m
        D_strain[0, 0] = E_m * (1.0 - nu_m) / (1.0 + nu_m) / (1.0 - 2.0 * nu_m)
        D_strain[0, 1] = E_m / (1.0 + nu_m) / (1.0 - 2.0 * nu_m) * nu_m
        D_strain[1, 0] = E_m / (1.0 + nu_m) / (1.0 - 2.0 * nu_m) * nu_m
        D_strain[1, 1] = E_m * (1.0 - nu_m) / (1.0 + nu_m) / (1.0 - 2.0 * nu_m)
        D_strain[2, 2] = E_m * (1.0 - nu_m) / (1.0 + nu_m) / (2.0 - 2.0 * nu_m)

        # Block 3-5: built from E_f, nu_f
        D_strain[3, 3] = E_f * (1.0 - nu_f) / (1.0 + nu_f) / (1.0 - 2.0 * nu_f)
        D_strain[3, 4] = E_f / (1.0 + nu_f) / (1.0 - 2.0 * nu_f) * nu_f
        D_strain[4, 3] = E_f / (1.0 + nu_f) / (1.0 - 2.0 * nu_f) * nu_f
        D_strain[4, 4] = E_f * (1.0 - nu_f) / (1.0 + nu_f) / (1.0 - 2.0 * nu_f)
        D_strain[5, 5] = E_f * (1.0 - nu_f) / (1.0 + nu_f) / (2.0 - 2.0 * nu_f)

        # Components 6 and 7: stiffness G
        D_strain[6, 6] = G
        D_strain[7, 7] = G
        return D_strain

    #---------------------------------------------------------------------------------------------
    # Response trace evaluators
    #---------------------------------------------------------------------------------------------

    def get_sig_norm(self, sctx, eps_app_eng):
        '''Return sqrt(sig[0]**2 + sig[1]**2) wrapped in a 1-element array.'''
        sig_eng, D_mtx = self.get_corr_pred(sctx, eps_app_eng, 0, 0, 0)
        return array([scalar_sqrt(sig_eng[0]**2 + sig_eng[1]**2)])

    def get_eps_app_m(self, sctx, eps_app_eng):
        '''Map strain components 0-2 to matrix form.'''
        return self.map_eps_eng_to_mtx((eps_app_eng[:3]))

    def get_eps_app_f(self, sctx, eps_app_eng):
        '''Map strain components 3-5 to matrix form.'''
        return self.map_eps_eng_to_mtx((eps_app_eng[3:6]))

    def get_sig_app_m(self, sctx, eps_app_eng, *args, **kw):
        '''Map stress components 0-2 to matrix form.'''
        sig_eng, D_mtx = self.get_corr_pred(sctx, eps_app_eng, 0, 0, 0)
        return self.map_sig_eng_to_mtx((sig_eng[:3]))

    def get_sig_app_f(self, sctx, eps_app_eng, *args, **kw):
        '''Map stress components 3-5 to matrix form.'''
        sig_eng, D_mtx = self.get_corr_pred(sctx, eps_app_eng, 0, 0, 0)
        return self.map_sig_eng_to_mtx((sig_eng[3:6]))

    def get_sig_b(self, sctx, eps_app_eng, *args, **kw):
        '''Return stress components 6 and 7 on the diagonal of a 2x2 array.'''
        sig_eng, D_mtx = self.get_corr_pred(sctx, eps_app_eng, 0, 0, 0)
        return array([[sig_eng[6], 0.], [0., sig_eng[7]]])

    # Declare and fill-in the rte_dict - it is used by the clients to
    # assemble all the available time-steppers.
    #
    rte_dict = Trait(Dict)

    def _rte_dict_default(self):
        '''Default mapping from response-trace name to evaluator method.'''
        return {
            'eps_app_f': self.get_eps_app_f,
            'eps_app_m': self.get_eps_app_m,
            'sig_app_f': self.get_sig_app_f,
            'sig_app_m': self.get_sig_app_m,
            'sig_norm': self.get_sig_norm,
            'sig_b': self.get_sig_b,
        }
Esempio n. 3
0
# Thanks for using Enthought open source!

# -------------------------------------------------------------------------
#  Imports:
# -------------------------------------------------------------------------

from traits.api import HasTraits, Trait, Enum, Range

from traitsui.api import View, Item, EnumEditor

# -------------------------------------------------------------------------
#  Trait definitions:
# -------------------------------------------------------------------------

# String choices accepted by the enumeration trait below.
values = [
    'one',
    'two',
    'three',
    'four',
]
enum = Enum(*values)
# NOTE: this module-level name shadows the builtin `range`; it is kept
# because the class below refers to it by this name.
range = Range(low=1, high=4)

# -------------------------------------------------------------------------
#  'TestEnumEditor' class:
# -------------------------------------------------------------------------


class TestEnumEditor(HasTraits):
    """Traits holder exposing a compound trait and a plain range trait."""

    # -------------------------------------------------------------------------
    #  Trait definitions:
    # -------------------------------------------------------------------------

    # Compound trait: default 1, combined from the module-level `enum` and
    # `range` trait definitions.
    value = Trait(1, enum, range)

    # Integer range from 0 to 4.
    other_value = Range(low=0, high=4)
Esempio n. 4
0
class TitleEditorDemo(HasTraits):
    """Demo object showing three titles rendered with TitleEditor.

    `title` is picked from a fixed list, `title_2` is free text, and
    `title_3` is recomputed from `value` whenever `value` changes.
    """

    # Define the selection of titles that can be displayed:
    title = Enum(
        'Select a new title from the drop down list below',
        'This is the TitleEditor demonstration',
        'Acme Widgets Sales for Each Quarter',
        'This is Not Intended to be a Real Application'
    )

    # A user settable version of the title:
    title_2 = Str('Type into the text field below to change this title')

    # A title driven by the result of a calculation:
    title_3 = Property(depends_on='value')

    # The number used to drive the calculation:
    value = Float

    # Define the test view:
    view = View(
        VGroup(
            VGroup(
                HGroup(
                    Item('title',
                         show_label=False,
                         springy=True,
                         editor=TitleEditor()
                         )
                ),
                Item('title'),
                show_border=True
            ),
            VGroup(
                HGroup(
                    Item('title_2',
                         show_label=False,
                         springy=True,
                         editor=TitleEditor()
                         )
                ),
                Item('title_2', label='Title'),
                show_border=True
            ),
            VGroup(
                HGroup(
                    Item('title_3',
                         show_label=False,
                         springy=True,
                         editor=TitleEditor()
                         )
                ),
                Item('value'),
                show_border=True
            )
        ),
        width=0.4
    )

    #-- Property Implementations ---------------------------------------------

    @cached_property
    def _get_title_3(self):
        """Return a sentence giving the square root of `value`.

        Negative numbers are reported as an imaginary root (the magnitude
        followed by ``i``).
        """
        number = self.value
        # Test the sign explicitly: on Python 3 a negative float raised to
        # 0.5 yields a complex number rather than an exception, so relying
        # on a failure to detect negative input does not work.
        if number < 0:
            return ('The square root of %s is %si' %
                    (number, (-number) ** 0.5))
        return ('The square root of %s is %s' %
                (number, number ** 0.5))
Esempio n. 5
0
class Node(Controller):
    """ Basic Node structure of the pipeline that need to be tuned.

    Attributes
    ----------
    name : str
        the node name
    full_name : str
        a unique name among all nodes and sub-nodes of the top level pipeline
    enabled : bool
        user parameter to control the node activation
    activated : bool
        parameter describing the node status

    Methods
    -------
    connect
    set_callback_on_plug
    get_plug_value
    set_plug_value
    get_trait
    """
    name = Str()
    enabled = Bool(default_value=True)
    activated = Bool(default_value=False)
    node_type = Enum(("processing_node", "view_node"))

    def __init__(self, pipeline, name, inputs, outputs):
        """ Generate a Node

        Parameters
        ----------
        pipeline: Pipeline (mandatory)
            the pipeline object where the node is added
        name: str (mandatory)
            the node name
        inputs: list of dict (mandatory)
            a list of input parameters containing a dictionary with default
            values (mandatory key: name)
        outputs: list of dict (mandatory)
            a list of output parameters containing a dictionary with default
            values (mandatory key: name)

        Raises
        ------
        Exception
            if a parameter is not a dict or has no "name" key
        """
        super(Node, self).__init__()
        self.pipeline = weak_proxy(pipeline)
        self.name = name
        self.plugs = SortedDictionary()
        # _callbacks -> (src_plug_name, dest_node, dest_plug_name)
        self._callbacks = {}

        # pair every parameter with its direction flag:
        # False for an input, True for an output
        parameters = [(parameter, False) for parameter in inputs]
        parameters.extend((parameter, True) for parameter in outputs)

        for parameter, parameter_type in parameters:
            # check if parameter is a dictionary as specified in the
            # docstring
            if not isinstance(parameter, dict):
                raise Exception("Can't create Node. Expect a dict structure "
                                "to initialize the Node, "
                                "got {0}: {1}".format(type(parameter),
                                                      parameter))
            # check if parameter contains a name item
            # as specified in the docstring
            if "name" not in parameter:
                raise Exception(
                    "Can't create parameter with unknown "
                    "identifier and parameter {0}".format(parameter))
            # work on a copy so the caller's dict is left untouched
            parameter = parameter.copy()
            plug_name = parameter.pop("name")
            # force the parameter type
            parameter["output"] = parameter_type
            # generate plug with input parameter and identifier name
            plug = Plug(**parameter)
            # update plugs list
            self.plugs[plug_name] = plug
            # add an event on plug to validate the pipeline
            plug.on_trait_change(pipeline.update_nodes_and_plugs_activation,
                                 "enabled")

        # add an event on the Node instance traits to validate the pipeline
        self.on_trait_change(pipeline.update_nodes_and_plugs_activation,
                             "enabled")

    @property
    def process(self):
        """ The object stored by the setter, passed through get_ref.
        """
        return get_ref(self._process)

    @process.setter
    def process(self, value):
        self._process = value

    @property
    def full_name(self):
        """ The node name, prefixed with the full name of the owning
        pipeline node when the pipeline has a parent pipeline.
        """
        if self.pipeline.parent_pipeline:
            return self.pipeline.pipeline_node.full_name + '.' + self.name
        else:
            return self.name

    @staticmethod
    def _value_callback(self, source_plug_name, dest_node, dest_plug_name,
                        value):
        """ Spread the source plug value to the destination plug.

        This is a static method taking the node explicitly as its first
        argument, so that connect() can bind a weak proxy of the node
        instead of a bound method.
        """
        try:
            dest_node.set_plug_value(dest_plug_name, value)
        except traits.TraitError:
            # best effort: a value rejected by the destination is dropped
            pass

    def _value_callback_with_logging(self, log_stream, prefix,
                                     source_plug_name, dest_node,
                                     dest_plug_name, value):
        """ Spread the source plug value to the destination plug, and log it in
        a stream for debugging.

        Nothing is logged or propagated when the node has no plug named
        source_plug_name.
        """
        plug = self.plugs.get(source_plug_name, None)
        if plug is None:
            return

        def _link_name(dest_node, plug, prefix, dest_plug_name,
                       source_node_or_process):
            # Build the dotted name of the destination plug for the log.
            external = True
            sibling = False
            # check if it is an external link: if source is not a parent of dest
            if hasattr(source_node_or_process, 'process') \
                    and hasattr(source_node_or_process.process, 'nodes'):
                source_node = source_node_or_process.process.pipeline_node
                children = [
                    x for k, x in source_node.process.nodes.items() if x != ''
                ]
                if dest_node in children:
                    external = False
            # check if it is a sibling node:
            # if external and source is not in dest
            if external:
                sibling = True
                if hasattr(dest_node, 'process') \
                        and hasattr(dest_node.process, 'nodes'):
                    children = [
                        x for k, x in dest_node.process.nodes.items()
                        if x != ''
                    ]
                    if source_node_or_process in children:
                        sibling = False
                    else:
                        # compare against the children's processes as well
                        children = [
                            x.process for x in children \
                            if hasattr(x, 'process')]
                    if source_node_or_process in children:
                        sibling = False
            if external:
                if sibling:
                    name = '.'.join(prefix.split('.')[:-2] \
                        + [dest_node.name, dest_plug_name])
                else:
                    name = '.'.join(prefix.split('.')[:-2] + [dest_plug_name])
            else:
                # internal connection in a (sub) pipeline
                name = prefix + dest_node.name
                if name != '' and not name.endswith('.'):
                    name += '.'
                name += dest_plug_name
            return name

        dest_plug = dest_node.plugs[dest_plug_name]
        print('value link:', \
            'from:', prefix + source_plug_name, \
            'to:', _link_name(dest_node, dest_plug, prefix, dest_plug_name,
                              self), \
            ', value:', repr(value), file=log_stream)
        log_stream.flush()

        # actually propagate
        dest_node.set_plug_value(dest_plug_name, value)

    def connect(self, source_plug_name, dest_node, dest_plug_name):
        """ Connect linked plugs of two nodes

        Parameters
        ----------
        source_plug_name: str (mandatory)
            the source plug name
        dest_node: Node (mandatory)
            the destination node
        dest_plug_name: str (mandatory)
            the destination plug name
        """
        # add a callback to spread the source plug value
        value_callback = SomaPartial(self.__class__._value_callback,
                                     weak_proxy(self), source_plug_name,
                                     weak_proxy(dest_node), dest_plug_name)
        self._callbacks[(source_plug_name, dest_node,
                         dest_plug_name)] = value_callback
        self.set_callback_on_plug(source_plug_name, value_callback)

    def disconnect(self, source_plug_name, dest_node, dest_plug_name):
        """ disconnect linked plugs of two nodes

        Parameters
        ----------
        source_plug_name: str (mandatory)
            the source plug name
        dest_node: Node (mandatory)
            the destination node
        dest_plug_name: str (mandatory)
            the destination plug name

        Raises
        ------
        KeyError
            if no such connection was registered with connect()
        """
        # remove the callback to spread the source plug value
        callback = self._callbacks.pop(
            (source_plug_name, dest_node, dest_plug_name))
        self.remove_callback_from_plug(source_plug_name, callback)

    def __getstate__(self):
        """ Remove the callbacks from the default __getstate__ result because
        they prevent Node instance from being used with pickle.

        Only the connection keys are kept; __setstate__ rebuilds the
        callbacks from them.
        """
        state = super(Node, self).__getstate__()
        # store a real list: a dict keys view cannot be pickled
        state['_callbacks'] = list(state['_callbacks'].keys())
        state['pipeline'] = get_ref(state['pipeline'])
        return state

    def __setstate__(self, state):
        """ Restore the callbacks that have been removed by __getstate__.
        """
        # Rebuild the callbacks exactly as connect() does: _value_callback
        # is static and expects the node itself as its first argument.
        state['_callbacks'] = dict(
            (key, SomaPartial(self.__class__._value_callback,
                              weak_proxy(self), key[0],
                              weak_proxy(key[1]), key[2]))
            for key in state['_callbacks'])
        if state['pipeline'] is state['process']:
            state['pipeline'] = state['process'] = weak_proxy(
                state['pipeline'])
        else:
            state['pipeline'] = weak_proxy(state['pipeline'])
        super(Node, self).__setstate__(state)
        for callback_key, value_callback in six.iteritems(self._callbacks):
            self.set_callback_on_plug(callback_key[0], value_callback)

    def set_callback_on_plug(self, plug_name, callback):
        """ Add an event when a plug change

        Parameters
        ----------
        plug_name: str (mandatory)
            a plug name
        callback: @f (mandatory)
            a callback function
        """
        self.on_trait_change(callback, plug_name)

    def remove_callback_from_plug(self, plug_name, callback):
        """ Remove an event when a plug change

        Parameters
        ----------
        plug_name: str (mandatory)
            a plug name
        callback: @f (mandatory)
            a callback function
        """
        self.on_trait_change(callback, plug_name, remove=True)

    def get_plug_value(self, plug_name):
        """ Return the plug value

        Parameters
        ----------
        plug_name: str (mandatory)
            a plug name

        Returns
        -------
        output: object
            the plug value
        """
        return getattr(self, plug_name)

    def set_plug_value(self, plug_name, value):
        """ Set the plug value

        Parameters
        ----------
        plug_name: str (mandatory)
            a plug name
        value: object (mandatory)
            the plug value we want to set
        """
        setattr(self, plug_name, value)

    def get_trait(self, trait_name):
        """ Return the desired trait

        Parameters
        ----------
        trait_name: str (mandatory)
            a trait name

        Returns
        -------
        output: trait
            the trait named trait_name
        """
        return self.trait(trait_name)
Esempio n. 6
0
class RangeSelection(AbstractController):
    """ Selects a range along the index or value axis.

    The user right-click-drags to select a region, which stays selected until
    the user left-clicks to deselect.

    The tool is a small state machine driven by **event_state** ("normal",
    "selecting", "selected", "moving"); the handler methods below are named
    ``<state>_<event>``.  The current selection is written to the metadata
    of the datasource found on the plot under the name given by **axis**.
    """

    # The axis to which this tool is perpendicular.
    axis = Enum("index", "value")

    # The selected region, expressed as a tuple in data space.  This updates
    # and fires change-events as the user is dragging.
    selection = Property

    # Whether the current selection replaces ("set") or extends ("append")
    # an existing one; written to the datasource metadata alongside the
    # selection.
    selection_mode = Enum("set", "append")

    # This event is fired whenever the user completes the selection, or when a
    # finalized selection gets modified.  The value of the event is the data
    # space range.
    selection_completed = Event

    # The name of the metadata on the datasource that we will write
    # self.selection to
    metadata_name = Str("selections")

    # Either "set" or "append", depending on whether self.append_key was
    # held down
    selection_mode_metadata_name = Str("selection_mode")

    # The name of the metadata on the datasource that we will set to a numpy
    # boolean array for masking the datasource's data
    mask_metadata_name = Str("selection_masks")

    # The possible event states of this selection tool (overrides
    # enable.Interactor).
    #
    # normal:
    #     Nothing has been selected, and the user is not dragging the mouse.
    # selecting:
    #     The user is dragging the mouse and actively changing the
    #     selection region; resizing of an existing selection also
    #     uses this mode.
    # selected:
    #     The user has released the mouse and a selection has been
    #     finalized.  The selection remains until the user left-clicks
    #     or self.deselect() is called.
    # moving:
    #     The user is moving (not resizing) the selection range.
    event_state = Enum("normal", "selecting", "selected", "moving")

    #------------------------------------------------------------------------
    # Traits for overriding default object relationships
    #
    # By default, the RangeSelection assumes that self.component is a plot
    # and looks for the mapper and the axis_index on it.  If this is not the
    # case, then any (or all) three of these can be overridden by directly
    # assigning values to them.  To unset them and have them revert to default
    # behavior, assign "None" to them.
    #------------------------------------------------------------------------

    # The plot associated with this tool. By default, this is just
    # self.component.
    plot = Property

    # The mapper associated with this tool. By default, this is the mapper
    # on **plot** that corresponds to **axis**.
    mapper = Property

    # The index to use for **axis**. By default, this is derived from
    # self.plot.orientation, but it can be overridden and set to 0 or 1.
    axis_index = Property

    # List of listeners that listen to selection events.
    listeners = List

    #------------------------------------------------------------------------
    # Configuring interaction control
    #------------------------------------------------------------------------

    # Can the user resize the selection once it has been drawn?
    enable_resize = Bool(True)

    # The pixel distance between the mouse event and a selection endpoint at
    # which the user action will be construed as a resize operation.
    resize_margin = Int(7)

    # Allow the left button to begin a selection?
    left_button_selects = Bool(False)

    # Disable all left-mouse button interactions?
    disable_left_mouse = Bool(False)

    # Allow the tool to be put into the deselected state via mouse clicks
    allow_deselection = Bool(True)

    # The minimum span, in pixels, of a selection region.  Any attempt to
    # select a region smaller than this will be treated as a deselection.
    minimum_selection = Int(5)

    # The key which, if held down while the mouse is being dragged, will
    # indicate that the selection should be appended to an existing selection
    # as opposed to overwriting it.
    append_key = Instance(KeySpec, args=(None, "control"))

    #------------------------------------------------------------------------
    # Private traits
    #------------------------------------------------------------------------

    # The value of the override plot to use, if any.  If None, then uses
    # self.component.
    _plot = Trait(None, Any)

    # The value of the override mapper to use, if any.  If None, then uses the
    # mapper on self.component.
    _mapper = Trait(None, Any)

    # Shadow trait for the **axis_index** property.
    _axis_index = Trait(None, None, Int)

    # The data space start and end coordinates of the selected region,
    # expressed as a list.
    _selection = Trait(None, None, Tuple, List, Array)

    # The selection in mask form.
    _selection_mask = Array

    # The end of the selection that is being actively modified by the mouse.
    _drag_edge = Enum("high", "low")

    #------------------------------------------------------------------------
    # These record the mouse position when the user is moving (not resizing)
    # the selection
    #------------------------------------------------------------------------

    # The position of the initial user click for moving the selection.
    _down_point = Array  # (x,y)

    # The data space coordinates of **_down_point**.
    _down_data_coord = Float

    # The original selection when the mouse went down to move the selection.
    _original_selection = Any

    #------------------------------------------------------------------------
    # Public methods
    #------------------------------------------------------------------------

    def deselect(self, event=None):
        """ Deselects the highlighted region.

        This method essentially resets the tool. It takes the event causing the
        deselection as an optional argument.

        The selection is cleared, **selection_completed** is fired with None,
        the tool returns to the 'normal' state and a redraw is requested.  If
        an event is given, the pointer is reset to an arrow and the event is
        marked as handled.
        """
        self.selection = None
        self.selection_completed = None
        self.event_state = "normal"
        self.component.request_redraw()
        if event:
            event.window.set_pointer("arrow")
            event.handled = True
        return

    #------------------------------------------------------------------------
    # Event handlers for the "selected" event state
    #------------------------------------------------------------------------

    def selected_left_down(self, event):
        """ Handles the left mouse button being pressed when the tool is in
        the 'selected' state.

        If the user is allowed to resize the selection, and the event occurred
        within the resize margin of an endpoint, then the tool switches to the
        'selecting' state so that the user can resize the selection.

        If the event is within the bounds of the selection region, then the
        tool switches to the 'moving' states.

        Otherwise, the selection becomes deselected.
        """
        if self.disable_left_mouse:
            return

        screen_bounds = self._get_selection_screencoords()
        if screen_bounds is None:
            self.deselect(event)
            return
        low = min(screen_bounds)
        high = max(screen_bounds)
        tmp = (event.x, event.y)
        ndx = self.axis_index
        mouse_coord = tmp[ndx]

        if self.enable_resize:
            # Clicking near either endpoint is handled exactly like a
            # right-button resize.
            if (abs(mouse_coord - high) <= self.resize_margin) or \
                    (abs(mouse_coord - low) <= self.resize_margin):
                return self.selected_right_down(event)

        if low <= mouse_coord <= high:
            # Remember where the drag started, in both screen and data
            # space, so moving_mouse_move() can compute the offset.
            self.event_state = "moving"
            self._down_point = array([event.x, event.y])
            self._down_data_coord = \
                self.mapper.map_data(self._down_point)[ndx]
            self._original_selection = array(self.selection)
        elif self.allow_deselection:
            self.deselect(event)
        else:
            # Treat this as a combination deselect + left down
            self.deselect(event)
            self.normal_left_down(event)
        event.handled = True
        return

    def selected_right_down(self, event):
        """ Handles the right mouse button being pressed when the tool is in
        the 'selected' state.

        If the user is allowed to resize the selection, and the event occurred
        within the resize margin of an endpoint, then the tool switches to the
        'selecting' state so that the user can resize the selection.

        Otherwise, the selection becomes deselected, and a new selection is
        started.
        """
        if self.enable_resize:
            coords = self._get_selection_screencoords()
            if coords is not None:
                start, end = coords
                tmp = (event.x, event.y)
                ndx = self.axis_index
                mouse_coord = tmp[ndx]
                # We have to do a little swapping; the "end" point
                # is always what gets updated, so if the user
                # clicked on the starting point, we have to reverse
                # the sense of the selection.
                if abs(mouse_coord - end) <= self.resize_margin:
                    self.event_state = "selecting"
                    self._drag_edge = "high"
                    self.selecting_mouse_move(event)
                elif abs(mouse_coord - start) <= self.resize_margin:
                    self.event_state = "selecting"
                    self._drag_edge = "low"
                    self.selecting_mouse_move(event)
                else:
                    # Treat this as a combination deselect + right down
                    self.deselect(event)
                    self.normal_right_down(event)
        else:
            # Treat this as a combination deselect + right down
            self.deselect(event)
            self.normal_right_down(event)
        event.handled = True
        return

    def selected_mouse_move(self, event):
        """ Handles the mouse moving when the tool is in the 'selected' state.

        If the user is allowed to resize the selection, and the event
        occurred within the resize margin of an endpoint, then the cursor
        changes to indicate that the selection could be resized.

        Otherwise, the cursor is set to an arrow.
        """
        if self.enable_resize:
            # Change the mouse cursor when the user moves within the
            # resize margin
            coords = self._get_selection_screencoords()
            if coords is not None:
                start, end = coords
                tmp = (event.x, event.y)
                ndx = self.axis_index
                mouse_coord = tmp[ndx]
                if abs(mouse_coord - end) <= self.resize_margin or \
                        abs(mouse_coord - start) <= self.resize_margin:
                    self._set_sizing_cursor(event)
                    return
        event.window.set_pointer("arrow")
        event.handled = True
        return

    def selected_mouse_leave(self, event):
        """ Handles the mouse leaving the plot when the tool is in the
        'selected' state.

        Sets the cursor to an arrow.
        """
        event.window.set_pointer("arrow")
        return

    #------------------------------------------------------------------------
    # Event handlers for the "moving" event state
    #------------------------------------------------------------------------

    def moving_left_up(self, event):
        """ Handles the left mouse button coming up when the tool is in the
        'moving' state.

        Switches the tool to the 'selected' state.
        """
        if self.disable_left_mouse:
            return

        self.event_state = "selected"
        self.selection_completed = self.selection
        self._down_point = []
        event.handled = True
        return

    def moving_mouse_move(self, event):
        """ Handles the mouse moving when the tool is in the 'moving' state.

        Moves the selection range by an amount corresponding to the amount
        that the mouse has moved since its button was pressed. If the new
        selection range overlaps the endpoints of the data, it is truncated to
        that endpoint.
        """
        cur_point = array([event.x, event.y])
        cur_data_point = self.mapper.map_data(cur_point)[self.axis_index]
        original_selection = self._original_selection
        new_selection = original_selection + (cur_data_point -
                                              self._down_data_coord)
        selection_data_width = original_selection[1] - original_selection[0]

        # Clamp to the mapper's data range while preserving the width of
        # the selection.
        data_range = self.mapper.range
        if min(new_selection) < data_range.low:
            new_selection = (data_range.low,
                             data_range.low + selection_data_width)
        elif max(new_selection) > data_range.high:
            new_selection = (data_range.high - selection_data_width,
                             data_range.high)

        self.selection = new_selection
        self.selection_completed = new_selection
        self.component.request_redraw()
        event.handled = True
        return

    def moving_mouse_leave(self, event):
        """ Handles the mouse leaving the plot while the tool is in the
        'moving' state.

        If the mouse was within the selection region when it left, the method
        does nothing.

        If the mouse was outside the selection region when it left, the event
        is treated as moving the selection to the minimum or maximum.
        """
        axis_index = self.axis_index
        low = self.plot.position[axis_index]
        high = low + self.plot.bounds[axis_index] - 1

        pos = self._get_axis_coord(event)
        if pos >= low and pos <= high:
            # the mouse left but was within the mapping range, so don't do
            # anything
            return
        else:
            # the mouse left and exceeds the mapping range, so we need to slam
            # the selection all the way to the minimum or the maximum
            self.moving_mouse_move(event)
        return

    def moving_mouse_enter(self, event):
        """ Handles the mouse entering the plot while the tool is in the
        'moving' state.

        If the left button is no longer down, the event is treated as if the
        left button had been released.
        """
        if not event.left_down:
            return self.moving_left_up(event)
        return

    #------------------------------------------------------------------------
    # Event handlers for the "normal" event state
    #------------------------------------------------------------------------

    def normal_left_down(self, event):
        """ Handles the left mouse button being pressed when the tool is in
        the 'normal' state.

        If the tool allows the left mouse button to start a selection, then
        it does so.
        """
        if self.left_button_selects:
            return self.normal_right_down(event)

    def normal_right_down(self, event):
        """ Handles the right mouse button being pressed when the tool is in
        the 'normal' state.

        Puts the tool into 'selecting' mode, changes the cursor to show that it
        is selecting, and starts defining the selection.

        """
        pos = self._get_axis_coord(event)
        mapped_pos = self.mapper.map_data(pos)
        self.selection = (mapped_pos, mapped_pos)
        self._set_sizing_cursor(event)
        self._down_point = array([event.x, event.y])
        self.event_state = "selecting"
        if self.append_key is not None and self.append_key.match(event):
            self.selection_mode = "append"
        else:
            self.selection_mode = "set"
        self.selecting_mouse_move(event)
        return

    #------------------------------------------------------------------------
    # Event handlers for the "selecting" event state
    #------------------------------------------------------------------------

    def selecting_mouse_move(self, event):
        """ Handles the mouse being moved when the tool is in the 'selecting'
        state.

        Expands the selection range at the appropriate end, based on the new
        mouse position.
        """
        if self.selection is not None:
            axis_index = self.axis_index
            low = self.plot.position[axis_index]
            high = low + self.plot.bounds[axis_index] - 1
            coord = self._get_axis_coord(event)
            if low <= coord <= high:
                new_edge = self.mapper.map_data(coord)
                # If the dragged edge crosses the fixed edge, swap which
                # edge is being dragged so the tuple stays (low, high).
                if self._drag_edge == "high":
                    low_val = self.selection[0]
                    if new_edge >= low_val:
                        self.selection = (low_val, new_edge)
                    else:
                        self.selection = (new_edge, low_val)
                        self._drag_edge = "low"
                else:
                    high_val = self.selection[1]
                    if new_edge <= high_val:
                        self.selection = (new_edge, high_val)
                    else:
                        self.selection = (high_val, new_edge)
                        self._drag_edge = "high"

                self.component.request_redraw()
            event.handled = True
        return

    def selecting_button_up(self, event):
        """ Finishes a selection drag, for either mouse button.

        If the drag spanned fewer than **minimum_selection** pixels from the
        recorded down point, the selection is cancelled via deselect();
        otherwise the tool switches to the 'selected' state and fires
        **selection_completed**.
        """
        event.window.set_pointer("arrow")

        end = self._get_axis_coord(event)

        # _down_point is an Array trait, so test its length rather than its
        # truth value.  It is empty when no down point was recorded (e.g. a
        # resize of an existing selection), in which case the selection is
        # never cancelled.
        if len(self._down_point) == 0:
            cancel_selection = False
        else:
            start = self._down_point[self.axis_index]
            self._down_point = []
            cancel_selection = self.minimum_selection > abs(start - end)

        if cancel_selection:
            self.deselect(event)
            event.handled = True
        else:
            self.event_state = "selected"

            # Fire the "completed" event
            self.selection_completed = self.selection
            event.handled = True
        return

    def selecting_right_up(self, event):
        """ Handles the right mouse button coming up when the tool is in the
        'selecting' state.

        Switches the tool to the 'selected' state and completes the selection.
        """
        self.selecting_button_up(event)

    def selecting_left_up(self, event):
        """ Handles the left mouse button coming up when the tool is in the
        'selecting' state.

        Switches the tool to the 'selected' state.
        """
        if self.disable_left_mouse:
            return
        self.selecting_button_up(event)

    def selecting_mouse_leave(self, event):
        """ Handles the mouse leaving the plot when the tool is in the
        'selecting' state.

        Determines whether the event's position is outside the component's
        bounds, and if so, clips the selection. Sets the cursor to an arrow.
        """
        old_selection = self.selection
        if old_selection is None:
            # Mirror the guard in selecting_mouse_move(): with no selection
            # there is nothing to clip, so only reset the cursor.
            event.window.set_pointer("arrow")
            return

        axis_index = self.axis_index
        low = self.plot.position[axis_index]
        high = low + self.plot.bounds[axis_index] - 1

        selection_low = old_selection[0]
        selection_high = old_selection[1]

        pos = self._get_axis_coord(event)
        if pos >= high:
            # clip to the boundary appropriate for the mapper's orientation.
            if self.mapper.sign == 1:
                selection_high = self.mapper.map_data(high)
            else:
                selection_high = self.mapper.map_data(low)
        elif pos <= low:
            if self.mapper.sign == 1:
                selection_low = self.mapper.map_data(low)
            else:
                selection_low = self.mapper.map_data(high)

        self.selection = (selection_low, selection_high)
        event.window.set_pointer("arrow")
        self.component.request_redraw()
        return

    def selecting_mouse_enter(self, event):
        """ Handles the mouse entering the plot when the tool is in the
        'selecting' state.

        If the mouse does not have the right mouse button down, this event
        is treated as if the right mouse button was released. Otherwise,
        the method sets the cursor to show that it is selecting.
        """
        # If we were in the "selecting" state when the mouse left, and
        # the mouse has entered without a button being down,
        # then treat this like we got a button up event.
        if not (event.right_down or event.left_down):
            return self.selecting_button_up(event)
        else:
            self._set_sizing_cursor(event)
        return

    #------------------------------------------------------------------------
    # Property getter/setters
    #------------------------------------------------------------------------

    def _get_plot(self):
        """ Returns the override plot if one was set, else self.component. """
        if self._plot is not None:
            return self._plot
        else:
            return self.component

    def _set_plot(self, val):
        """ Stores an override plot (None reverts to self.component). """
        self._plot = val
        return

    def _get_mapper(self):
        """ Returns the override mapper if one was set, else the
        ``<axis>_mapper`` attribute of the plot.
        """
        if self._mapper is not None:
            return self._mapper
        else:
            return getattr(self.plot, self.axis + "_mapper")

    def _set_mapper(self, new_mapper):
        """ Stores an override mapper (None reverts to the plot's mapper). """
        self._mapper = new_mapper
        return

    def _get_axis_index(self):
        """ Returns the override axis index if one was set, else the value
        computed by _determine_axis().
        """
        if self._axis_index is None:
            return self._determine_axis()
        else:
            return self._axis_index

    def _set_axis_index(self, val):
        """ Stores an override axis index (None reverts to the default). """
        self._axis_index = val
        return

    def _get_selection(self):
        """ Returns the selection stored in the datasource's metadata under
        **metadata_name**.
        """
        selection = getattr(self.plot, self.axis).metadata[self.metadata_name]
        return selection

    def _set_selection(self, val):
        """ Stores a new selection and publishes it.

        The value is written to the datasource metadata (range, mode and
        boolean mask), a property-changed notification is fired for
        "selection", and every listener that has a ``set_value_selection``
        method is called with the new value.
        """
        oldval = self._selection
        self._selection = val

        datasource = getattr(self.plot, self.axis, None)

        if datasource is not None:

            mdname = self.metadata_name

            # Set the selection range on the datasource
            datasource.metadata[mdname] = val
            datasource.metadata_changed = {mdname: val}

            # Set the selection mask on the datasource: first drop the mask
            # this tool contributed previously (matched by identity, since
            # other entries in the list are not ours to remove).
            selection_masks = \
                datasource.metadata.setdefault(self.mask_metadata_name, [])
            previous_mask = self._selection_mask
            for index, mask in enumerate(selection_masks):
                if mask is previous_mask:
                    del selection_masks[index]
                    break

            # Set the selection mode on the datasource
            datasource.metadata[self.selection_mode_metadata_name] = \
                self.selection_mode

            if val is not None:
                low, high = val
                data_pts = datasource.get_data()
                new_mask = (data_pts >= low) & (data_pts <= high)
                selection_masks.append(new_mask)
                self._selection_mask = new_mask
            datasource.metadata_changed = {self.mask_metadata_name: val}

        self.trait_property_changed("selection", oldval, val)

        for l in self.listeners:
            if hasattr(l, "set_value_selection"):
                l.set_value_selection(val)

        return

    #------------------------------------------------------------------------
    # Private methods
    #------------------------------------------------------------------------

    def _get_selection_screencoords(self):
        """ Returns a tuple of (x1, x2) screen space coordinates of the start
        and end selection points.

        If there is no current selection, then it returns None.
        """
        selection = self.selection
        if selection is not None and len(selection) == 2:
            return self.mapper.map_screen(array(selection))
        else:
            return None

    def _set_sizing_cursor(self, event):
        """ Sets the correct cursor shape on the window of the event, given the
        tool's orientation and axis.
        """
        if self.axis_index == 0:
            # horizontal range selection, so use left/right arrow
            event.window.set_pointer("size left")
        else:
            # vertical range selection, so use up/down arrow
            event.window.set_pointer("size top")
        return

    def _get_axis_coord(self, event, axis="index"):
        """ Returns the coordinate of the event along the axis of interest
        to this tool (or along the orthogonal axis, if axis="value").
        """
        event_pos = (event.x, event.y)
        if axis == "index":
            return event_pos[self.axis_index]
        else:
            return event_pos[1 - self.axis_index]

    def _determine_axis(self):
        """ Determines whether the index of the coordinate along this tool's
        axis of interest is the first or second element of an (x,y) coordinate
        tuple.

        This method is only called if self._axis_index hasn't been set (or is
        None).
        """
        if self.axis == "index":
            if self.plot.orientation == "h":
                return 0
            else:
                return 1
        else:  # self.axis == "value"
            if self.plot.orientation == "h":
                return 1
            else:
                return 0

    def __mapper_changed(self):
        """ Resets the tool when the observed mapper on the plot changes. """
        self.deselect()
        return

    def _axis_changed(self, old, new):
        """ Moves the mapper listener when **axis** changes.

        The listener on the plot's ``<old>_mapper`` trait is removed and a
        listener is attached to its ``<new>_mapper`` trait, so that the tool
        deselects whenever the mapper for the active axis is replaced.
        """
        plot = self.plot
        if plot is None:
            # No plot to observe (presumably axis was assigned before the
            # component); nothing can be hooked up.
            return
        if old is not None:
            plot.on_trait_change(self.__mapper_changed,
                                 old + "_mapper",
                                 remove=True)
        if new is not None:
            # Attach (not remove) the listener, and on the *new* axis'
            # mapper; the previous code repeated the removal for the old one.
            plot.on_trait_change(self.__mapper_changed,
                                 new + "_mapper")
        return
Esempio n. 7
0
class TabularEditor(BasicEditorFactory):
    """ Editor factory for tabular editors.

    The class is a set of configuration traits; the toolkit-specific editor
    class to instantiate is exposed through the ``klass`` property.  Most of
    the ``Str`` traits hold the extended name of another trait that the
    editor synchronizes with.
    """

    # -- Trait Definitions ----------------------------------------------------

    #: The editor class to be created:
    klass = Property()

    #: Should column headers (i.e. titles) be displayed?
    show_titles = Bool(True)

    #: Should row headers be displayed (Qt4 only)?
    show_row_titles = Bool(False)

    #: The optional extended name of the trait used to indicate that a complete
    #: table update is needed:
    update = Str()

    #: The optional extended name of the trait used to indicate that the table
    #: just needs to be repainted.
    refresh = Str()

    #: Should the table update automatically when the table item's contents
    #: change? Note that in order for this feature to work correctly, the
    #: editor trait should be a list of objects derived from HasTraits. Also,
    #: performance can be affected when very long lists are used, since
    #: enabling this feature adds and removes Traits listeners on each item in
    #: the list.
    auto_update = Bool(False)

    #: The optional extended name of the trait to synchronize the selection
    #: values with:
    selected = Str()

    #: The optional extended name of the trait to synchronize the selection
    #: rows with:
    selected_row = Str()

    #: Whether or not to allow selection.
    selectable = Bool(True)

    #: The optional extended name of the trait to synchronize the activated
    #: value with:
    activated = Str()

    #: The optional extended name of the trait to synchronize the activated
    #: value's row with:
    activated_row = Str()

    #: The optional extended name of the trait to synchronize left click data
    #: with. The data is a TabularEditorEvent:
    clicked = Str()

    #: The optional extended name of the trait to synchronize left double click
    #: data with. The data is a TabularEditorEvent:
    dclicked = Str()

    #: The optional extended name of the trait to synchronize right click data
    #: with. The data is a TabularEditorEvent:
    right_clicked = Str()

    #: The optional extended name of the trait to synchronize right double
    #: clicked data with. The data is a TabularEditorEvent:
    right_dclicked = Str()

    #: The optional extended name of the trait to synchronize column
    #: clicked data with. The data is a TabularEditorEvent:
    column_clicked = Str()

    #: The optional extended name of the trait to synchronize column
    #: right clicked data with. The data is a TabularEditorEvent:
    column_right_clicked = Str()

    #: The optional extended name of the Event trait that should be used to
    #: trigger a scroll-to command. The data is an integer giving the row.
    scroll_to_row = Str()

    #: The optional extended name of the Event trait that should be used to
    #: trigger a scroll-to command. The data is an integer giving the column.
    scroll_to_column = Str()

    #: Deprecated: Controls behavior of scroll to row and scroll to column.
    #: Reads and writes are forwarded to scroll_to_position_hint and emit a
    #: DeprecationWarning.
    scroll_to_row_hint = Property(Str, observe="scroll_to_position_hint")

    #: (replacement of scroll_to_row_hint, but more clearly named)
    #: Controls behavior of scroll to row and scroll to column
    scroll_to_position_hint = Enum("visible", "center", "top", "bottom")

    #: Can the user edit the values?
    editable = Bool(True)

    #: Can the user edit the labels (i.e. the first column)
    editable_labels = Bool(False)

    #: Are multiple selected items allowed?
    multi_select = Bool(False)

    #: Should horizontal lines be drawn between items?
    horizontal_lines = Bool(True)

    #: Should vertical lines be drawn between items?
    vertical_lines = Bool(True)

    #: Should the columns automatically resize? Don't allow this when the
    #: amount of data is large.
    auto_resize = Bool(False)

    #: Should the rows automatically resize (Qt4 only)? Don't allow
    #: this when the amount of data is large.
    auto_resize_rows = Bool(False)

    #: Whether to stretch the last column to fit the available space.
    stretch_last_section = Bool(True)

    #: The adapter from trait values to editor values:
    adapter = Instance("traitsui.tabular_adapter.TabularAdapter", ())

    #: What type of operations are allowed on the list (all of them by
    #: default):
    operations = List(
        Enum("delete", "insert", "append", "edit", "move"),
        ["delete", "insert", "append", "edit", "move"],
    )

    #: Are 'drag_move' operations allowed (i.e. True), or should they always be
    #: treated as 'drag_copy' operations (i.e. False):
    drag_move = Bool(True)

    #: The set of images that can be used:
    images = List(Image)

    def _get_klass(self):
        """ Returns the toolkit-specific editor class to be instantiated.

        The class is looked up with
        ``toolkit_object("tabular_editor:TabularEditor")``.
        """
        return toolkit_object("tabular_editor:TabularEditor")

    def _get_scroll_to_row_hint(self):
        """ Getter for the deprecated ``scroll_to_row_hint`` property.

        Emits a DeprecationWarning and returns the value of
        ``scroll_to_position_hint``.
        """
        warnings.warn(
            "Use of scroll_to_row_hint trait is deprecated. "
            "Use scroll_to_position_hint instead.",
            DeprecationWarning,
        )
        return self.scroll_to_position_hint

    def _set_scroll_to_row_hint(self, hint):
        """ Setter for the deprecated ``scroll_to_row_hint`` property.

        Emits a DeprecationWarning and assigns ``hint`` to
        ``scroll_to_position_hint``.
        """
        warnings.warn(
            "Use of scroll_to_row_hint trait is deprecated. "
            "Use scroll_to_position_hint instead.",
            DeprecationWarning,
        )
        self.scroll_to_position_hint = hint
Esempio n. 8
0
class TableColumn(HasPrivateTraits):
    """Represents a column in a table editor.

    The column's display and editing settings are stored as traits and are
    exposed through ``get_*`` / ``is_*`` accessor methods.  Most accessors
    take the row object as an argument; the implementations in this class
    ignore it and return the column-wide trait value.
    """

    # -------------------------------------------------------------------------
    #  Trait definitions:
    # -------------------------------------------------------------------------

    #: Column label to use for this column:
    label = Str(UndefinedLabel)

    #: Type of data contained by the column:
    # XXX currently no other types supported, but potentially there could be...
    type = Enum("text", "bool")

    #: Text color for this column:
    text_color = Color("black")

    #: Text font for this column:
    text_font = Union(None, Font)

    #: Cell background color for this column:
    cell_color = Color("white", allow_none=True)

    #: Cell background color for non-editable columns:
    read_only_cell_color = Color(0xF4F3EE, allow_none=True)

    #: Cell graph color:
    graph_color = Color(0xDDD9CC)

    #: Horizontal alignment of text in the column:
    horizontal_alignment = Enum("left", ["left", "center", "right"])

    #: Vertical alignment of text in the column:
    vertical_alignment = Enum("center", ["top", "center", "bottom"])

    #: Horizontal cell margin
    horizontal_margin = Int(4)

    #: Vertical cell margin
    vertical_margin = Int(3)

    #: The image to display in the cell:
    image = Image

    #: Renderer used to render the contents of this column:
    renderer = Any  # A toolkit specific renderer

    #: Is the table column visible (i.e., viewable)?
    visible = Bool(True)

    #: Is this column editable?
    editable = Bool(True)

    #: Is the column automatically edited/viewed (i.e. should the column editor
    #: or popup be activated automatically on mouse over)?
    auto_editable = Bool(False)

    #: Should a checkbox be displayed instead of True/False?
    show_checkbox = Bool(True)

    #: Can external objects be dropped on the column?
    droppable = Bool(False)

    #: Context menu to display when this column is right-clicked:
    menu = Instance(Menu)

    #: The tooltip to display when the mouse is over the column:
    tooltip = Str()

    #: The width of the column (< 0.0: Default, 0.0..1.0: fraction of total
    #: table width, > 1.0: absolute width in pixels):
    width = Float(-1.0)

    #: The width of the column while it is being edited (< 0.0: Default,
    #: 0.0..1.0: fraction of total table width, > 1.0: absolute width in
    #: pixels):
    edit_width = Float(-1.0)

    #: The height of the column cell's row while it is being edited
    #: (< 0.0: Default, 0.0..1.0: fraction of total table height,
    #: > 1.0: absolute height in pixels):
    edit_height = Float(-1.0)

    #: The resize mode for this column.  This takes precedence over other
    #: settings (like **width**, above).
    #: - "interactive": column can be resized by users or programmatically
    #: - "fixed": users cannot resize the column, but it can be set programmatically
    #: - "stretch": the column will be resized to fill the available space
    #: - "resize_to_contents": column will be sized to fit the contents, but then cannot be resized
    resize_mode = Enum("interactive", "fixed", "stretch", "resize_to_contents")

    #: The view (if any) to display when clicking a non-editable cell:
    view = AView

    #: Optional maximum value a numeric cell value can have:
    maximum = Float(trait_value=True)

    # -------------------------------------------------------------------------
    #  Accessor methods:
    # -------------------------------------------------------------------------

    def get_object(self, object):
        """Returns the actual object being edited."""
        return object

    def get_label(self):
        """Gets the label of the column."""
        return self.label

    def get_width(self):
        """Returns the width of the column."""
        return self.width

    def get_edit_width(self, object):
        """Returns the edit width of the column."""
        return self.edit_width

    def get_edit_height(self, object):
        """Returns the height of the column cell's row while it is being
        edited.
        """
        return self.edit_height

    def get_type(self, object):
        """Gets the type of data for the column for a specified object."""
        return self.type

    def get_text_color(self, object):
        """Returns the text color for the column for a specified object."""
        # The trailing underscore reads the ``text_color_`` companion
        # attribute rather than ``text_color`` itself.
        return self.text_color_

    def get_text_font(self, object):
        """Returns the text font for the column for a specified object."""
        return self.text_font

    def get_cell_color(self, object):
        """Returns the cell background color for the column for a specified
        object.

        Editable cells use ``cell_color_``; all others use
        ``read_only_cell_color_``.
        """
        if self.is_editable(object):
            return self.cell_color_
        return self.read_only_cell_color_

    def get_graph_color(self, object):
        """Returns the cell background graph color for the column for a
        specified object.
        """
        return self.graph_color_

    def get_horizontal_alignment(self, object):
        """Returns the horizontal alignment for the column for a specified
        object.
        """
        return self.horizontal_alignment

    def get_vertical_alignment(self, object):
        """Returns the vertical alignment for the column for a specified
        object.
        """
        return self.vertical_alignment

    def get_image(self, object):
        """Returns the image to display for the column for a specified object."""
        return self.image

    def get_renderer(self, object):
        """Returns the renderer for the column of a specified object."""
        return self.renderer

    def is_editable(self, object):
        """Returns whether the column is editable for a specified object."""
        return self.editable

    def is_auto_editable(self, object):
        """Returns whether the column is automatically edited/viewed for a
        specified object.
        """
        return self.auto_editable

    def is_droppable(self, object, value):
        """Returns whether a specified value is valid for dropping on the
        column for a specified object.
        """
        return self.droppable

    def get_menu(self, object):
        """Returns the context menu to display when the user right-clicks on
        the column for a specified object.
        """
        return self.menu

    def get_tooltip(self, object):
        """Returns the tooltip to display when the user mouses over the column
        for a specified object.
        """
        return self.tooltip

    def get_view(self, object):
        """Returns the view to display when clicking a non-editable cell."""
        return self.view

    def get_maximum(self, object):
        """Returns the maximum value a numeric column can have."""
        return self.maximum

    def on_click(self, object):
        """Called when the user clicks on the column."""
        pass

    def on_dclick(self, object):
        """Called when the user double-clicks on the column."""
        pass

    def cmp(self, object1, object2):
        """Returns the result of comparing the column of two different objects.

        The result is -1, 0 or 1 depending on whether the key of ``object1``
        is less than, equal to or greater than the key of ``object2``.

        This is deprecated.
        """
        # NOTE(review): ``key`` is not defined in this class as shown here;
        # presumably it is supplied by a subclass -- confirm.
        # Each key is computed once rather than once per comparison.
        key1 = self.key(object1)
        key2 = self.key(object2)
        return (key1 > key2) - (key1 < key2)

    def __str__(self):
        """Returns the string representation of the table column."""
        return self.get_label()
Esempio n. 9
0
"""
 Defines the class that describes the information on the inputs and
 outputs of an object in the pipeline.
"""
# Author: Prabhu Ramachandran <*****@*****.**>
# Copyright (c) 2008-2020, Prabhu Ramachandran Enthought, Inc.
# License: BSD Style.

# Enthought library imports.
from traits.api import HasTraits, Enum, List

from .utils import get_tvtk_dataset_name

# Enum trait listing the names of the supported dataset kinds.  The first
# value given ('none') is the one the Enum uses as its default.
DataSet = Enum('none', 'any', 'image_data', 'rectilinear_grid', 'poly_data',
               'structured_grid', 'unstructured_grid')

# Enum trait for the attribute type; allowed values are 'any', 'cell',
# 'point' and 'none', with 'any' listed first.
AttributeType = Enum('any', 'cell', 'point', 'none')

# Enum trait for the attribute itself; allowed values are 'any', 'none',
# 'scalars', 'vectors' and 'tensors', with 'any' listed first.
Attribute = Enum('any', 'none', 'scalars', 'vectors', 'tensors')


################################################################################
# `PipelineInfo` class.
################################################################################
class PipelineInfo(HasTraits):
    """
    This class represents the information that a particular input or
    output of an object should contain.
Esempio n. 10
0
class MATS2DMicroplaneDamage(MATSXDMicroplaneDamage, MATS2DEval):
    """Two-dimensional specialization of the microplane damage model.

    Fixes the spatial dimension to 2, provides the microplane normals and
    weights for the planar case and derives the elasticity / compliance
    tensors for the selected ``stress_state``.
    """

    # implements(IMATSEval)

    # number of spatial dimensions
    #
    n_dim = Constant(2)

    # number of components of engineering tensor representation
    #
    n_eng = Constant(3)

    # planar constraint
    stress_state = Enum("plane_strain", "plane_stress")

    # Specify the class to use for directional dependence
    mfn_class = Type(MFnPolar)

    # get the normal vectors of the microplanes
    _MPN = Property(depends_on='n_mp')

    @cached_property
    def _get__MPN(self):
        """Return one ``[cos(alpha), sin(alpha)]`` row per angle in
        ``self.alpha_list``.
        """
        return array([[cos(alpha), sin(alpha)] for alpha in self.alpha_list])

    # get the weights of the microplanes
    _MPW = Property(depends_on='n_mp')

    @cached_property
    def _get__MPW(self):
        """Return ``n_mp`` equal weights, each ``2 / n_mp``."""
        return ones(self.n_mp) / self.n_mp * 2

    elasticity_tensors = Property(depends_on='E, nu, stress_state')

    @cached_property
    def _get_elasticity_tensors(self):
        '''
        Initialize the fourth order elasticity and compliance tensors
        for the 2D case selected by ``stress_state``.

        Returns a tuple ``(D4_e, C4_e, D2_e)``: the fourth order elasticity
        tensor, the fourth order compliance tensor and the second order
        (matrix) elasticity representation for the active stress state.

        Raises ``ValueError`` if ``stress_state`` is neither
        ``'plane_stress'`` nor ``'plane_strain'``.
        '''
        # ---------------------------------------------------------------------
        # Lame constants calculated from E and nu
        # ---------------------------------------------------------------------
        E = self.E
        nu = self.nu

        # first Lame parameter
        la = E * nu / ((1 + nu) * (1 - 2 * nu))
        # second Lame parameter (shear modulus)
        mu = E / (2 + 2 * nu)

        # ---------------------------------------------------------------------
        # Fourth order elasticity and compliance tensors for the 3D-case
        # ---------------------------------------------------------------------
        # The lines below are the vectorized form of the index expressions
        #
        #   D4_e_3D[i,j,k,l] = la * delta[i,j] * delta[k,l] + \
        #       mu * (delta[i,k] * delta[j,l] + delta[i,l] * delta[j,k])
        #   (elasticity tensor, cf. Jir/Baz Inelastic analysis of structures
        #   Eq.D25)
        #
        #   C4_e_3D[i,j,k,l] = (1+nu)/(2*E) * \
        #       (delta[i,k] * delta[j,l] + delta[i,l] * delta[j,k]) - \
        #       nu / E * delta[i,j] * delta[k,l]
        #   (elastic compliance tensor, cf. Simo, Computational Inelasticity,
        #   Eq.(2.7.16) AND (2.1.16))
        #
        # NOTE: swapaxes returns a reference not a copy!
        # (the index notation always refers to the initial indexing
        # (i=0,j=1,k=2,l=3))
        delta = identity(3)
        delta_ijkl = outer(delta, delta).reshape(3, 3, 3, 3)
        delta_ikjl = delta_ijkl.swapaxes(1, 2)
        delta_iljk = delta_ikjl.swapaxes(2, 3)
        D4_e_3D = la * delta_ijkl + mu * (delta_ikjl + delta_iljk)
        C4_e_3D = -nu / E * delta_ijkl + \
            (1 + nu) / (2 * E) * (delta_ikjl + delta_iljk)

        # ---------------------------------------------------------------------
        # Fourth order elasticity and compliance tensors for the 2D-case
        # ---------------------------------------------------------------------
        # 1. step: Get the (6x6)-elasticity and compliance matrices
        #          for the 3D-case:
        D2_e_3D = map3d_tns4_to_tns2(D4_e_3D)
        C2_e_3D = map3d_tns4_to_tns2(C4_e_3D)

        # 2. step: Get the (3x3)-elasticity and compliance matrices for the
        #          active 2D-case.  Only the selected state is reduced, and
        #          an unexpected state fails with a clear error instead of
        #          leaving the result names unbound.
        stress_state = self.stress_state
        if stress_state == 'plane_stress':
            D2_e = get_D_plane_stress(D2_e_3D)
            C2_e = get_C_plane_stress(C2_e_3D)
        elif stress_state == 'plane_strain':
            D2_e = get_D_plane_strain(D2_e_3D)
            C2_e = get_C_plane_strain(C2_e_3D)
        else:
            raise ValueError(
                'unsupported stress_state: %r' % (stress_state,))

        # 3. step: Get the fourth order elasticity and compliance tensors
        # for the active 2D-case (D4.shape = (2,2,2,2))
        D4_e = map2d_tns2_to_tns4(D2_e)
        C4_e = map2d_tns2_to_tns4(C2_e)

        return D4_e, C4_e, D2_e

    def _get_explorer_config(self):
        '''Get the specific configuration of this material model in the explorer

        Extends the configuration returned by the parent class by appending
        a ``MATS2DRTraceCylinder`` tracer to its ``'rtrace_list'`` entry.
        '''
        c = super(MATS2DMicroplaneDamage, self)._get_explorer_config()

        # Imported here rather than at module level, as in the original code.
        from ibvpy.tmodel.mats2D.mats2D_rtrace_cylinder import MATS2DRTraceCylinder

        # overload the default configuration
        c['rtrace_list'] += [
            MATS2DRTraceCylinder(name='Laterne',
                                 var_axis='time',
                                 idx_axis=0,
                                 var_surface='microplane_damage',
                                 record_on='update'),
        ]

        return c

    #-------------------------------------------------------------------------
    # Dock-based view with its own id
    #-------------------------------------------------------------------------
    traits_view = View(Include('polar_fn_group'),
                       dock='tab',
                       id='ibvpy.tmodel.mats3D.mats_2D_cmdm.MATS2D_cmdm',
                       kind='modal',
                       resizable=True,
                       scrollable=True,
                       width=0.6,
                       height=0.8,
                       buttons=['OK', 'Cancel'])
Esempio n. 11
0
# The valid categories of imported elements that can be dragged into the view
# (a list of strings):
ImportTypes = List(Str,
                   desc='the categories of elements that can be '
                   'dragged into the view')

# The view position and size traits.  All four share the default -1E6.
# NOTE(review): -1E6 presumably acts as a "not set" sentinel -- confirm
# against the code that consumes these traits.
Width = Float(-1E6, desc='the width of the view window')
Height = Float(-1E6, desc='the height of the view window')
XCoordinate = Float(-1E6, desc='the x coordinate of the view window')
YCoordinate = Float(-1E6, desc='the y coordinate of the view window')

# The result that should be returned if the user clicks the window or dialog
# close button or icon.  Allowed values are None, True and False; None is
# listed first.
CloseResult = Enum(None,
                   True,
                   False,
                   desc='the result to return when the user clicks the '
                   'window or dialog close button or icon')

# The KeyBindings trait.  The class is referenced by its dotted name as a
# string rather than by an imported class object:
AKeyBindings = Instance('traitsui.key_bindings.KeyBindings',
                        desc='the global key bindings for the view')

#-------------------------------------------------------------------------
#  'View' class:
#-------------------------------------------------------------------------


class View(ViewElement):
    """ A Traits-based user interface for one or more objects.
Esempio n. 12
0
class SaveTool(BaseTool):
    """ This tool allows the user to press Ctrl+S to save a snapshot image of
    the plot component.

    Files whose name ends in ``.pdf`` (in any letter case) are written
    through the PDF graphics context; everything else is written as a raster
    image.
    """

    # The file that the image is saved in.  The format will be deduced from
    # the extension.
    filename = Str("saved_plot.png")

    #-------------------------------------------------------------------------
    # PDF format options
    # These mirror the traits in PdfPlotGraphicsContext.
    #-------------------------------------------------------------------------

    pagesize = Enum("letter", "A4")
    dest_box = Tuple((0.5, 0.5, -0.5, -0.5))
    dest_box_units = Enum("inch", "cm", "mm", "pica")

    #-------------------------------------------------------------------------
    # Override default trait values inherited from BaseTool
    #-------------------------------------------------------------------------

    # This tool does not have a visual representation (overrides BaseTool).
    draw_mode = "none"

    # This tool is not visible (overrides BaseTool).
    visible = False

    def normal_key_pressed(self, event):
        """ Handles a key-press when the tool is in the 'normal' state.

        Saves an image of the plot if the keys pressed are Control and S,
        and marks the event as handled.  Does nothing when the tool has no
        component.
        """
        if self.component is None:
            return

        if event.character == "s" and event.control_down:
            # Compare the extension case-insensitively so that "plot.PDF"
            # is also routed to the PDF writer.
            extension = os.path.splitext(self.filename)[-1].lower()
            if extension == ".pdf":
                self._save_pdf()
            else:
                self._save_raster()
            event.handled = True

    def _save_raster(self):
        """ Saves an image of the component.

        The graphics context is sized from the component's outer width and
        height, truncated to integers.
        """
        from chaco.api import PlotGraphicsContext
        gc = PlotGraphicsContext((int(self.component.outer_width),
                                  int(self.component.outer_height)))
        self.component.draw(gc, mode="normal")
        gc.save(self.filename)

    def _save_pdf(self):
        """ Saves the component to ``self.filename`` as a PDF, using the
        ``pagesize``, ``dest_box`` and ``dest_box_units`` traits.
        """
        from chaco.pdf_graphics_context import PdfPlotGraphicsContext
        gc = PdfPlotGraphicsContext(filename=self.filename,
                                    pagesize=self.pagesize,
                                    dest_box=self.dest_box,
                                    dest_box_units=self.dest_box_units)
        gc.render_component(self.component)
        gc.save()
Esempio n. 13
0
class Controls(HasTraits):
    """Control panel for the simulation parameters and the MIDI devices.

    The class body picks default MIDI input/output devices from the
    module-level ``inputs`` and ``outputs`` sequences, declares the editable
    parameters and builds the view.  The ``_*_fired`` handlers send the
    corresponding value to the MIDI output, if one is open.
    """

    # Start from "no device" so the name is always bound, even when several
    # inputs exist and every one of them is a "Through Port".
    default_input = None
    if len(inputs) == 1:
        default_input = inputs

    # Prefer the first input whose name does not mention "Through Port".
    for i in inputs:
        if "Through Port" not in i[1]:
            default_input = i
            break

    default_output = -1
    through_port_output = None
    for i in outputs:
        if "Through Port" not in i[1]:
            default_output = i
            break
        else:
            through_port_output = i
    # NOTE(review): with one output or none, the through-port candidate
    # (possibly None) replaces whatever the loop selected -- confirm intent.
    default_output = default_output if len(
        outputs) > 1 else through_port_output

    if default_input is None or default_output is None:
        print('Cannot connect to any MIDI device')

    input_device = List(value=default_input,
                        editor=CheckListEditor(values=inputs))
    output_device = List(value=default_output,
                         editor=CheckListEditor(values=outputs))

    # Limits used by the range editors and by the MIDI value scaling.
    max_temp = 2.
    min_temp = 0.5
    max_press = 10.
    min_press = 5e-4
    max_vol = 100000.
    min_vol = 50.
    max_n = 1000
    min_n = 50

    temperature = Range(
        min_temp,
        max_temp,
        1.,
    )
    volume = Float(box_l**3.)
    pressure = Float(1.)
    number_of_particles = Range(
        min_n,
        max_n,
        n_part,
    )
    ensemble = Enum('NVT', 'NPT')

    # Open MIDI device handles; None while no device is open.
    midi_input = None
    midi_output = None

    # MIDI status bytes used for the individual parameters.
    MIDI_BASE = 224
    MIDI_NUM_TEMPERATURE = MIDI_BASE + 0
    MIDI_NUM_VOLUME = MIDI_BASE + 1
    MIDI_NUM_PRESSURE = MIDI_BASE + 2
    MIDI_NUM_NUMBEROFPARTICLES = MIDI_BASE + 3

    MIDI_ROTATE = 0

    MIDI_ZOOM = 144

    _ui = Any
    view = View(Group(Item('temperature',
                           editor=RangeEditor(low_name='min_temp',
                                              high_name='max_temp')),
                      Item('volume',
                           editor=RangeEditor(low_name='min_vol',
                                              high_name='max_vol')),
                      Item('pressure',
                           editor=RangeEditor(low_name='min_press',
                                              high_name='max_press')),
                      Item('number_of_particles',
                           editor=RangeEditor(low_name='min_n',
                                              high_name='max_n',
                                              is_float=False)),
                      Item('ensemble', style='custom'),
                      show_labels=True,
                      label='Parameters'),
                Group(Item('input_device'),
                      Item('output_device'),
                      show_labels=True,
                      label='MIDI devices'),
                buttons=[],
                title='Control',
                height=0.2,
                width=0.3)

    def __init__(self, **traits):
        """Create the controls, open the UI and push the initial values."""
        super(Controls, self).__init__(**traits)
        self._ui = self.edit_traits()
        self.push_current_values()

    def push_current_values(self):
        """send the current values to the MIDI controller"""
        self._temperature_fired()
        self._volume_fired()
        self._pressure_fired()
        self._number_of_particles_fired()
        self._ensemble_fired()

    def _input_device_fired(self):
        """Close the current MIDI input and open the newly selected one."""
        if self.midi_input is not None:
            self.midi_input.close()
            # Do not keep a reference to a closed device.
            self.midi_input = None
        if self.input_device:
            self.midi_input = midi.Input(self.input_device[0])

    def _output_device_fired(self):
        """Close the current MIDI output, open the newly selected one and
        resend all current values to it.
        """
        if self.midi_output is not None:
            self.midi_output.close()
            # Do not keep a reference to a closed device.
            self.midi_output = None
        # An empty selection must not raise IndexError (this mirrors the
        # guard in _input_device_fired).
        if self.output_device:
            self.midi_output = midi.Output(self.output_device[0])
        self.push_current_values()

    def _temperature_fired(self):
        """Send the temperature, scaled linearly onto 0..127."""
        status = self.MIDI_NUM_TEMPERATURE
        data1 = int((self.temperature - self.min_temp) /
                    (self.max_temp - self.min_temp) * 127)
        data2 = data1
        if self.midi_output is not None:
            self.midi_output.write_short(status, data1, data2)

    def _volume_fired(self):
        """Send the volume, scaled linearly and limited to 0..127.

        The volume is taken from ``system.box_l[0]**3`` rather than from the
        ``volume`` trait.
        """
        status = self.MIDI_NUM_VOLUME
        data1 = limit_range(int((system.box_l[0]**3. - self.min_vol) /
                                (self.max_vol - self.min_vol) * 127),
                            minval=0,
                            maxval=127)
        data2 = data1

        if self.midi_output is not None:
            self.midi_output.write_short(status, data1, data2)

    def _pressure_fired(self):
        """Send the pressure, limited to 0..127.

        The scaling is logarithmic when the module-level
        ``pressure_log_flag`` is true and linear otherwise.
        """
        status = self.MIDI_NUM_PRESSURE

        if pressure_log_flag:
            data1 = limit_range(
                int(127 * (np.log(self.pressure) - np.log(self.min_press)) /
                    (np.log(self.max_press) - np.log(self.min_press))),
                minval=0,
                maxval=127)
        else:
            data1 = limit_range(int((self.pressure - self.min_press) /
                                    (self.max_press - self.min_press) * 127),
                                minval=0,
                                maxval=127)
        data2 = data1
        if self.midi_output is not None:
            self.midi_output.write_short(status, data1, data2)

    def _number_of_particles_fired(self):
        """Send the particle number as a fraction of ``max_n`` times 127."""
        status = self.MIDI_NUM_NUMBEROFPARTICLES
        data1 = int(self.number_of_particles / self.max_n * 127)
        data2 = data1
        if self.midi_output is not None:
            self.midi_output.write_short(status, data1, data2)

    def _ensemble_fired(self):
        """Send four status-144 messages (data1 = 0..3) for the ensemble.

        Entries 0 (T) and 3 (N) are always sent as 127; entry 1 (V) is 127
        unless the ensemble is 'NPT', entry 2 (P) is 127 only for 'NPT'.
        """
        if self.midi_output is not None:
            self.midi_output.write_short(144, 0, 127)  # T
            self.midi_output.write_short(144, 1,
                                         127 * (self.ensemble != 'NPT'))  # V
            self.midi_output.write_short(144, 2,
                                         127 * (self.ensemble == 'NPT'))  # P
            self.midi_output.write_short(144, 3, 127)  # N
Esempio n. 14
0
class LUTManager(Base):

    # The version of this class.  Used for persistence.
    __version__ = 0

    # The lookup table.
    lut = Instance(tvtk.LookupTable, (), record=False)
    # The scalar bar.
    scalar_bar = Instance(tvtk.ScalarBarActor, (), record=True)
    # The scalar_bar_widget
    scalar_bar_widget = Instance(tvtk.ScalarBarWidget, ())

    # The representation associated with the scalar_bar_widget.  This
    # only exists in VTK versions about around 5.2.
    scalar_bar_representation = Instance(tvtk.Object, allow_none=True,
                                         record=True)

    # The title text property of the axes.
    title_text_property = Property(record=True)

    # The label text property of the axes.
    label_text_property = Property(record=True)

    # The current mode of the LUT.
    lut_mode = Enum('blue-red', lut_mode_list(),
                     desc='the type of the lookup table')

    # File name of the LUT file to use.
    file_name = Str('', editor=FileEditor,
                    desc='the filename containing the LUT')

    # Reverse the colors of the LUT.
    reverse_lut = Bool(False, desc='if the lut is to be reversed')

    # Turn on/off the visibility of the scalar bar.
    show_scalar_bar = Bool(False,
                           desc='if scalar bar is shown or not')

    # This is an alias for show_scalar_bar.
    show_legend = Property(Bool, desc='if legend is shown or not')

    # The number of labels to use for the scalar bar.
    number_of_labels = Range(0, 64, 8, enter_set=True, auto_set=False,
                             desc='the number of labels to display')

    # Number of colors for the LUT.
    number_of_colors = Range(2, 2147483647, 256, enter_set=True,
                             auto_set=False,
                             desc='the number of colors for the LUT')

    # Enable shadowing of the labels and text.
    shadow = Bool(False, desc='if the labels and text have shadows')

    # Use the default data name or the user specified one.
    use_default_name = Bool(True,
                            desc='if the default data name is to be used')

    # The default data name -- set by the module manager.
    default_data_name = Str('data', enter_set=True, auto_set=False,
                            desc='the default data name')

    # The optionally user specified name of the data.
    data_name = Str('', enter_set=True, auto_set=False,
                    desc='the title of the legend')

    # Use the default range or user specified one.
    use_default_range = Bool(True,
                             desc='if the default data range is to be used')
    # The default data range -- this is computed and set by the
    # module manager.
    default_data_range = Array(shape=(2,), value=[0.0, 1.0],
                               dtype=float, enter_set=True, auto_set=False,
                               desc='the default range of the data mapped')

    # The optionally user defined range of the data.
    data_range = Array(shape=(2,), value=[0.0, 1.0],
                       dtype=float, enter_set=True, auto_set=False,
                       desc='the range of the data mapped')

    # Create a new LUT.
    create_lut = Button('Launch LUT editor',
                        desc='if we launch a Lookup table editor in'
                             ' a separate process')

    ########################################
    ## Private traits.
    # The original range of the data.
    _orig_data_range = Array(shape=(2,), value=[0.0, 1.0], dtype=float)
    _title_text_property = Instance(tvtk.TextProperty)
    _label_text_property = Instance(tvtk.TextProperty)

    ######################################################################
    # `object` interface
    ######################################################################
    def __init__(self, **traits):
        """Set up the scalar bar, the lookup table and the scalar bar widget.

        All keyword arguments are forwarded to the parent constructor.
        """
        super(LUTManager, self).__init__(**traits)

        # Configure the scalar bar and where it is placed.
        bar = self.scalar_bar
        bar.set(lookup_table=self.lut,
                title=self.data_name,
                number_of_labels=self.number_of_labels,
                orientation='horizontal',
                width=0.8, height=0.17)
        position = bar.position_coordinate
        position.set(coordinate_system='normalized_viewport',
                     value=(0.1, 0.01, 0.0))
        self._shadow_changed(self.shadow)

        # Build the lookup table for the current mode.
        self._lut_mode_changed(self.lut_mode)

        # Keep private references to the scalar bar's text properties.
        title_prop = bar.title_text_property
        self._title_text_property = title_prop
        label_prop = bar.label_text_property
        self._label_text_property = label_prop

        # Re-render whenever either text property changes.
        for text_prop in (title_prop, label_prop):
            text_prop.on_trait_change(self.render)

        # Attach the scalar bar to its widget.
        self.scalar_bar_widget.set(scalar_bar_actor=self.scalar_bar,
                                   key_press_activation=False)
        self._number_of_colors_changed(self.number_of_colors)


    ######################################################################
    # `Base` interface
    ######################################################################
    def start(self):
        """Invoked when this object is added to the mayavi pipeline.

        Has no effect when the object is already running.
        """
        if not self.running:
            # Bring the legend in line with the ``show_scalar_bar`` setting.
            self._show_scalar_bar_changed(self.show_scalar_bar)

            # The parent method sets the running state.
            super(LUTManager, self).start()

    def stop(self):
        """Invoked when this object is removed from the mayavi pipeline.

        Has no effect when the object is not running.
        """
        if self.running:
            # Switch the scalar bar widget off if it has an interactor.
            widget = self.scalar_bar_widget
            if widget.interactor is not None:
                widget.off()

            # The parent method sets the running state.
            super(LUTManager, self).stop()

    ######################################################################
    # Non-public interface
    ######################################################################
    def _lut_mode_changed(self, value):
        """Rebuild the lookup table for the LUT mode ``value``.

        * ``'file'``: load the table from ``file_name`` when one is set.
        * a key of ``pylab_luts``: load that table, reversed and subsampled
          as required.
        * ``'blue-red'`` / ``'black-white'``: configure the table through
          its hue, saturation and value ranges, then render.
        """
        if value == 'file':
            if self.file_name:
                self.load_lut_from_file(self.file_name)
            return

        reverse = self.reverse_lut

        if value in pylab_luts:
            table = pylab_luts[value]
            if reverse:
                table = table[::-1, :]
            n_total = len(table)
            n_color = self.number_of_colors
            # Fewer colors requested than the table holds: keep every
            # step-th row.
            if n_color < n_total:
                table = table[::round(n_total/float(n_color))]
            self.load_lut_from_list(table.tolist())
            return

        if value == 'blue-red':
            hue_range = (0.0, 0.6667) if reverse else (0.6667, 0.0)
            saturation_range = 1.0, 1.0
            value_range = 1.0, 1.0
        elif value == 'black-white':
            hue_range = 0.0, 0.0
            saturation_range = 0.0, 0.0
            value_range = (1.0, 0.0) if reverse else (0.0, 1.0)

        lut = self.lut
        lut.set(hue_range=hue_range, saturation_range=saturation_range,
                value_range=value_range,
                number_of_table_values=self.number_of_colors,
                ramp='sqrt')
        lut.modified()
        lut.force_build()

        self.render()

    def _scene_changed(self, value):
        sbw = self.scalar_bar_widget
        if value is None:
            return
        if sbw.interactor is not None:
            sbw.off()
        value.add_widgets(sbw, enabled=False)
        if self.show_scalar_bar:
            sbw.on()
        self._foreground_changed_for_scene(None, value.foreground)

    def _foreground_changed_for_scene(self, old, new):
        # Change the default color for the text.
        self.title_text_property.color = new
        self.label_text_property.color = new
        self.render()

    def _number_of_colors_changed(self, value):
        """Apply a new number of colours `value` to the LUT and scalar bar.

        Nothing is done in ``'file'`` mode.  Tabulated ``pylab_luts``
        colormaps are sub-sampled (requests for more colours than the table
        holds are ignored); for the other modes the table size of
        ``self.lut`` is changed and the table rebuilt.  Finally the scalar
        bar's maximum number of colours is updated.
        """
        if self.lut_mode == 'file':
            return
        elif self.lut_mode in pylab_luts:
            # We can't interpolate these LUTs, as they are defined from a
            # table. We hack around this limitation by sub-sampling.
            lut = pylab_luts[self.lut_mode]
            if self.reverse_lut:
                lut = lut[::-1, :]
            n_total = len(lut)
            if value > n_total:
                return
            # A slice step must be an integer: round() returns a float on
            # Python 2, which NumPy no longer accepts as an index.
            step = int(round(n_total / float(value)))
            lut = lut[::step]
            self.load_lut_from_list(lut.tolist())
        else:
            lut = self.lut
            lut.number_of_table_values = value
            lut.modified()
            lut.build()
            self.render()  # necessary to flush.
        sc_bar = self.scalar_bar
        sc_bar.maximum_number_of_colors = value
        sc_bar.modified()
        self.render()

    def _number_of_labels_changed(self, value):
        sc_bar = self.scalar_bar
        sc_bar.number_of_labels = value
        sc_bar.modified()
        self.render()

    def _file_name_changed(self, value):
        if self.lut_mode == 'file':
            self.load_lut_from_file(value)
        else:
            # This will automagically load the LUT from the file.
            self.lut_mode = 'file'

    def _reverse_lut_changed(self, value):
        # This will do the needful.
        self._lut_mode_changed(self.lut_mode)

    def _show_scalar_bar_changed(self, value):
        if self.scene is not None:
            # Without a title for scalar bar actor, vtkOpenGLTexture logs this:
            # Error: No scalar values found for texture input!
            if self.scalar_bar.title == '':
                self.scalar_bar.title = ' '
            self.scalar_bar_widget.enabled = value
            self.render()

    def _get_show_legend(self):
        return self.show_scalar_bar

    def _set_show_legend(self, value):
        old = self.show_scalar_bar
        if value != old:
            self.show_scalar_bar = value
            self.trait_property_changed('show_legend', old, value)

    def _shadow_changed(self, value):
        sc_bar = self.scalar_bar
        sc_bar.title_text_property.shadow = self.shadow
        sc_bar.label_text_property.shadow = self.shadow
        self.render()

    def _use_default_name_changed(self, value):
        self._default_data_name_changed(self.default_data_name)

    def _data_name_changed(self, value):
        sc_bar = self.scalar_bar
        sc_bar.title = value
        sc_bar.modified()
        self.render()

    def _default_data_name_changed(self, value):
        if self.use_default_name:
            self.data_name = value

    def _use_default_range_changed(self, value):
        self._default_data_range_changed(self.default_data_range)

    def _data_range_changed(self, value):
        """Apply the data range `value` to the lookup table and re-render.

        `value` is indexed as ``value[0]`` / ``value[1]``.  Three ways of
        setting the range on ``self.lut`` are tried in turn (see the inline
        comments), then the scalar bar is flagged as modified.
        """
        try:
            # First choice: set_range() taking two separate arguments.
            self.lut.set_range(value[0], value[1])
        except TypeError:
            # Fallback: set_range() taking a single tuple.
            # NOTE(review): presumably copes with differing tvtk/VTK
            # signatures -- confirm.
            self.lut.set_range((value[0], value[1]))
        except AttributeError:
            # Last resort: assign the ``range`` attribute directly.
            self.lut.range = value
        self.scalar_bar.modified()
        self.render()

    def _default_data_range_changed(self, value):
        if self.use_default_range:
            self.data_range = value

    def _visible_changed(self, value):
        """Show or hide the scalar bar when the object's visibility changes.

        The bar is shown only if ``show_scalar_bar`` is set and `value` is
        true; the parent class handler is called afterwards.
        """
        self._show_scalar_bar_changed(self.show_scalar_bar and value)
        super(LUTManager, self)._visible_changed(value)

    def load_lut_from_file(self, file_name):
        """Load the lookup table stored in `file_name` into ``self.lut``.

        An empty `file_name` is ignored.  Failure to open or to parse the
        file is reported through ``error()`` and leaves the current LUT
        untouched.  On success the colour list is reversed when
        ``reverse_lut`` is set, installed via ``set_lut`` and the scene is
        re-rendered.
        """
        if len(file_name) == 0:
            return

        # Probe that the file can be opened so that this failure gets its
        # own message, distinct from a parse failure.
        try:
            with open(file_name, 'r'):
                pass
        except IOError:
            error("Cannot open Lookup Table file: %s\n" % file_name)
            return

        try:
            lut_list = parse_lut_file(file_name)
        except IOError as err_msg:
            msg = "Sorry could not parse LUT file: %s\n" % file_name
            # The exception must be converted explicitly; adding it to a
            # str directly raises TypeError and hides the real problem.
            msg += str(err_msg)
            error(msg)
            return

        if self.reverse_lut:
            lut_list.reverse()
        self.lut = set_lut(self.lut, lut_list)
        self.render()

    def load_lut_from_list(self, list):
        """Install the colours in `list` as the lookup table and re-render.

        `list` is handed to ``set_lut`` together with the current LUT.
        """
        updated_lut = set_lut(self.lut, list)
        self.lut = updated_lut
        self.render()

    def _get_title_text_property(self):
        return self._title_text_property

    def _get_label_text_property(self):
        return self._label_text_property

    def _create_lut_fired(self):
        """Launch the gradient editor script in a separate process.

        The script ``wx_gradient_editor.py`` is looked up next to the
        ``tvtk.util`` module and run with the current Python interpreter; a
        message announcing the launch is then shown.
        """
        from tvtk import util
        util_dir = os.path.dirname(util.__file__)
        editor_script = os.path.join(util_dir, 'wx_gradient_editor.py')
        command = [sys.executable, editor_script]
        subprocess.Popen(command)
        auto_close_message('Launching LUT editor in separate process ...')

    def _scalar_bar_representation_default(self):
        w = self.scalar_bar_widget
        if hasattr(w, 'representation'):
            r = w.representation
            r.on_trait_change(self.render)
            return r
        else:
            return None
Esempio n. 15
0
class ParticleScanner(HasTraits):
    """
    Find particles in an image and move to them
    """
    # --- Image-processing settings (read by denoise_image / threshold_image).
    median_filter_width = Range(1, 31)
    threshold_block_size = Int(21)
    threshold_level = Range(-255, 255, 50, mode="slider")
    # Which processing stage is applied to the live camera feed.
    live_filter = Enum(["None", "Denoised", "Thresholded"])

    # --- Scan control and progress reporting (updated by go_to_particles).
    scan_current_view = Button()
    abort_scan_button = Button(label="abort_scan")
    scan_status = String("Not Scanning")
    scan_progress = Range(0., 100., 0.)
    scanning = Bool(False)
    # Builtin int replaces the removed ``np.int`` alias (same dtype).
    tiled_scan_size = Array(shape=(2, ), dtype=int)
    start_tiled_scan = Button()
    # Width, in pixels, of the image margin ignored when finding particles.
    border_pixels = 15

    traits_view = View(Tabbed(
        VGroup(
            Item(name="median_filter_width"),
            Item(name="threshold_block_size"),
            Item(name="threshold_level"),
            Item(name="live_filter"),
            label="Image Processing",
        ),
        VGroup(
            Item(name="scan_status", style="readonly"),
            Item(name="scan_progress", style="readonly"),
            Item(name="scanning", style="readonly"),
            Item(name="scan_current_view"),
            Item(name="abort_scan_button"),
            Item(name="tiled_scan_size"),
            Item(name="start_tiled_scan"),
            label="Scan Control",
        ),
    ),
                       title="Particle Scanner")

    def __init__(self, camera_stage_mapper, spectrometer, spectrometer_aligner,
                 datafile):
        """Store the hardware/data handles and set up scan bookkeeping.

        camera_stage_mapper: stored as ``self.csm``.
        spectrometer: stored as ``self.spectrometer``.
        spectrometer_aligner: stored as ``self.aligner``.
        datafile: stored as ``self.datafile``.
        """
        super(ParticleScanner, self).__init__()
        self.csm = camera_stage_mapper
        self.spectrometer = spectrometer
        self.datafile = datafile
        self.aligner = spectrometer_aligner
        self._live_filter_changed()  #enable video filter if required
        self._scan_lock = threading.Lock()
        self._abort_scan_event = threading.Event()
        # No background scan exists yet; abort_scan() inspects this
        # attribute, so it must be defined before the first scan starts.
        self._scan_thread = None

    def SendCompleteMessage(self, number):
        """E-mail a "scan finished" notification via Gmail's SMTP server.

        number: particle count reported in the message body.

        Sending is best-effort: any failure is reported on stdout and not
        raised to the caller.
        """
        # NOTE(review): credentials are hard-coded in source; they should be
        # moved to configuration/secret storage and the password rotated.
        gmail_user = "******"
        gmail_pwd = "NQ3dPv6SXZUEdfTE"
        FROM = '*****@*****.**'
        TO = ['*****@*****.**']  #must be a list
        SUBJECT = "Scan finished"
        TEXT = "%d particles scanned" % number

        # Prepare actual message.  The header block must start directly with
        # "From:" (no leading backslash) for it to be recognised.
        message = "From: %s\nTo: %s\nSubject: %s\n\n%s\n" % (
            FROM, ", ".join(TO), SUBJECT, TEXT)
        server = None
        try:
            server = smtplib.SMTP("smtp.gmail.com",
                                  587)  #or port 465 doesn't seem to work!
            server.ehlo()
            server.starttls()
            server.login(gmail_user, gmail_pwd)
            server.sendmail(FROM, TO, message)
            print('successfully sent the mail')
        except Exception as exc:
            # Deliberately broad: a failed notification must never abort the
            # caller, but the reason is reported instead of being swallowed.
            print("failed to send mail: %s" % exc)
        finally:
            # Release the connection on failure as well as on success.
            if server is not None:
                server.close()

    def denoise_image(self, img):
        """Apply the current denoising filter to `img` and return the result.

        The image is blurred with a square ``cv2.blur`` kernel of side
        ``median_filter_width``.  An even width is first bumped to the next
        odd number, which updates the ``median_filter_width`` attribute
        itself.  A non-positive width returns `img` unchanged.
        """
        if not self.median_filter_width > 0:
            return img
        if self.median_filter_width % 2 == 0:
            self.median_filter_width += 1  #guard against even integers!
        width = self.median_filter_width
        return cv2.blur(img, (width, width))

    def threshold_image(self, img):
        """Apply threshold with the current settings to an image.

        `img` is adaptively thresholded (mean method, binary output) with a
        block size derived from ``threshold_block_size`` and the constant
        ``threshold_level``, then morphologically opened with a square kernel
        of side ``median_filter_width``.  Returns the processed image.
        """
        # Force an odd, integer block size.  Floor division keeps the value
        # an int under true division as well.
        block_size = (int(self.threshold_block_size) // 2) * 2 + 1
        img = cv2.adaptiveThreshold(
            img, 255, cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY,
            block_size, self.threshold_level)
        kernel = np.ones((self.median_filter_width, self.median_filter_width),
                         np.uint8)
        # After thresholding, erode then dilate to kill small blobs/noise.
        return cv2.morphologyEx(img, cv2.MORPH_OPEN, kernel, iterations=1)

    def camera_filter_function(self, frame):
        """Process a camera `frame` according to ``live_filter``.

        The frame is converted from RGB to grey, denoised when the filter is
        "Denoised", denoised and thresholded when it is "Thresholded", and
        converted back to RGB before being returned.
        """
        grey = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        mode = self.live_filter
        if mode in ("Denoised", "Thresholded"):
            grey = self.denoise_image(grey)
            if mode == "Thresholded":
                grey = self.threshold_image(grey)
        return cv2.cvtColor(grey, cv2.COLOR_GRAY2RGB)

    def _live_filter_changed(self):
        if self.live_filter == "None":
            self.csm.camera.filter_function = None
        else:
            self.csm.camera.filter_function = self.camera_filter_function

    @on_trait_change("find_particles")
    def find_particles_in_new_image(self):
        """Trait-change handler that runs ``find_particles`` on a new frame.

        The handler deliberately accepts no parameters besides ``self``:
        this stops extra arguments from Traits messing things up.
        """
        self.find_particles()

    def find_particles(self, img=None):
        """Find particles in the supplied image, or in the camera image.

        img: image to search.  When ``None`` a raw camera snapshot is
            converted to grey, denoised, thresholded and cropped by
            ``border_pixels`` on every side.

        Returns a list of NumPy arrays, one centre of mass per connected
        region of `img`.  ``border_pixels`` is added to every coordinate so
        that positions found in the cropped camera image refer to the full
        frame (the same offset is applied to a supplied `img`).
        """
        border = self.border_pixels
        if img is None:
            _, frame = self.csm.camera.raw_snapshot()
            grey = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
            processed = self.threshold_image(self.denoise_image(grey))
            img = processed[border:-border, border:-border]  #ignore the edges
        labels, nlabels = ndimage.label(img)
        if nlabels == 0:
            # Nothing was found; skip the centre-of-mass computation.
            return []
        centres = ndimage.center_of_mass(img, labels,
                                         list(range(1, nlabels + 1)))
        # Undo the crop: shift every position by the border width.
        return [np.array(centre) + border for centre in centres]

    def go_to_particles(self,
                        payload_function=lambda: time.sleep(2),
                        background=True,
                        max_n_particles=None):
        """Find particles, then visit each one in turn and execute a payload.

        payload_function: callable without arguments run at each particle
            (by default it simply waits for 2 seconds).
        background: when true the scan runs in a new thread and this method
            returns immediately (``None``); otherwise the scan runs in the
            calling thread and the result is returned.
        max_n_particles: optional limit on the number of particles visited.

        The scan can be monitored through traits scan_status, scan_progress,
        scanning.  It can be aborted with the abort_scan() method.

        Returns ``None`` if a scan is already flagged as running; for a
        foreground scan, ``True`` if it ran to completion and ``False`` if it
        was aborted.  Raises ``Exception`` if the scan lock is already held.
        """
        if self.scanning:
            return

        def worker_function():
            if not self._scan_lock.acquire(False):
                raise Exception(
                    "Tried to start a scan, but one was in progress!")
            aborted = False
            try:
                self.scanning = True
                self.scan_progress = 0
                self.scan_status = "Setting up scan..."
                here = self.csm.camera_centre_position()
                pixel_positions = self.find_particles()
                positions = [
                    self.csm.camera_pixel_to_sample(p) for p in pixel_positions
                ]
                image = self.csm.camera.color_image()
                border = self.border_pixels
                feature_images = []  #extract feature images
                for p in pixel_positions:
                    # Centres of mass are floats; slice bounds must be ints.
                    row, col = int(p[0]), int(p[1])
                    feature_images.append(image[row - border:row + border,
                                                col - border:col + border])
                for index, p in enumerate(positions):
                    if (max_n_particles is not None
                            and index >= max_n_particles):
                        print("Terminating scan as we've now scanned "
                              "enough particles")
                        break
                    self.scan_status = "Scanning particle %d of %d" % (
                        index, len(positions))
                    self.csm.move_to_sample_position(p)
                    time.sleep(0.3)
                    self.csm.centre_on_feature(feature_images[index])
                    payload_function()
                    self.scan_progress = (float(index) /
                                          float(len(positions)) * 100)
                    if self._abort_scan_event.is_set():
                        #this event lets us abort a scan
                        self.scan_status = "Scan Aborted."
                        self._abort_scan_event.clear()
                        aborted = True
                        break
                self.csm.move_to_sample_position(here)
                if aborted:
                    # Keep the "Scan Aborted." status visible.
                    print("Scan Aborted.")
                else:
                    self.scan_status = "Scan Finished"
                    print("Scan Finished :)")
                self.scan_progress = 100.0
            finally:
                # Always clear the running state, even if the hardware or the
                # payload raised; otherwise no further scan could be started.
                self.scanning = False
                self._scan_lock.release()
            return not aborted

        #execute the above function in the background
        if background:
            self._scan_thread = threading.Thread(target=worker_function)
            self._scan_thread.start()
        else:  #if we elected not to use a thread, just do it!
            return worker_function()

    @on_trait_change("abort_scan_button")
    def abort_scan(self):
        """Abort a currently-running scan in a background thread.

        Sets the abort event if a scan thread exists and is alive; does
        nothing otherwise, including when no background scan has been
        started yet.
        """
        # The thread attribute may not exist before the first background
        # scan, so look it up defensively instead of raising AttributeError.
        scan_thread = getattr(self, "_scan_thread", None)
        if scan_thread is not None and scan_thread.is_alive():
            self._abort_scan_event.set()


#        if self._scan_thread is not None:
#            self._scan_thread.join()
#        self._abort_scan_event.clear()

    def tile_scans(self,
                   size,
                   background=True,
                   tile_start_function=None,
                   ts_args=[],
                   ts_kwargs={},
                   *args,
                   **kwargs):
        """Run ``go_to_particles`` on a grid of fields of view.

        size: two-element grid shape (number of tiles along each axis).
        background: run the whole tiled scan in a new thread when true.
        tile_start_function: optional callable invoked at every tile before
            its scan, as ``tile_start_function(*ts_args, **ts_kwargs)``.
        ts_args, ts_kwargs: arguments for `tile_start_function` (the shared
            defaults are only read, never modified).
        *args, **kwargs: forwarded to ``go_to_particles``.

        Tiles are visited in snake-style raster order around the current
        position.  The tiling stops early if a tile's scan does not return a
        true value.  Afterwards the stage returns to the start position and a
        completion e-mail is sent with the number of ``z_scan_`` entries in
        the highest-numbered scan group of the data file.
        """
        def worker_function():
            grid_size = np.array(size)
            here = self.csm.camera_centre_position()
            # Floor division keeps integer tile offsets under true division.
            half_grid = grid_size // 2
            scan_centres = []
            for i in range(grid_size[0]):
                columns = range(grid_size[1])
                if i % 2 != 0:
                    columns = reversed(columns)  #snake-style raster scanning
                for j in columns:
                    scan_centres.append(
                        self.csm.camera_point_to_sample(
                            np.array([i, j]) - half_grid))
            for centre in scan_centres:
                print("Taking a scan with centre %.1f, %.1f um" %
                      tuple(centre))
                self.csm.move_to_sample_position(centre)
                if tile_start_function is not None:
                    tile_start_function(*ts_args, **ts_kwargs)
                # Each tile is scanned synchronously within this worker; the
                # extra positional arguments are passed through unchanged.
                ret = self.go_to_particles(*args, background=False, **kwargs)
                if not ret:
                    print("Scan aborted!")
                    break
            self.csm.move_to_sample_position(here)
            print("Scan Finished!")
            scans = self.datafile['particleScans']
            scan_groups = sorted(
                [scans[k] for k in scans.keys() if 'scan' in k],
                key=lambda g: int(re.search(r"(\d+)$", g.name).groups()[0]))
            if scan_groups:
                latest_group = scan_groups[-1]
                number_of_particles = len(
                    [k for k in latest_group.keys() if 'z_scan_' in k])
            else:
                # No scan group was recorded; report zero particles.
                number_of_particles = 0
            self.SendCompleteMessage(number_of_particles)

        #execute the above function in the background
        if background:
            self._scan_thread = threading.Thread(target=worker_function)
            self._scan_thread.start()
        else:  #if we elected not to use a thread, just do it!
            worker_function()

    def _scan_current_view_fired(self):
        self.take_zstacks_of_particles()

    def take_zstacks_of_particles(self,
                                  dz=np.arange(-2.5, 2.5, 0.2),
                                  datafile_group=None,
                                  *args,
                                  **kwargs):
        """Visit each particle and scan it spectrally.

        dz: array handed to ``pf_align_and_take_z_scan``.
        datafile_group: group to store the data in; a new
            "particleScans/scan%d" group is created when ``None``.
        *args, **kwargs: forwarded to ``go_to_particles``.

        The spectrometer live view is turned off, the Raman wavelengths and
        overview images are stored in the group, then the particle scan is
        started with the align-and-z-scan payload.
        """
        self.spectrometer.live_view = False
        if datafile_group is None:
            group = self.new_data_group("particleScans/scan%d", self.datafile)
        else:
            group = datafile_group
        group.create_dataset("Raman_wavelengths", data=raman.GetWavelength())
        self.save_overview_images(group)
        payload = self.pf_align_and_take_z_scan(dz, group)
        self.go_to_particles(payload, *args, **kwargs)

    def _start_tiled_scan_fired(self):
        self.take_zstacks_of_particles_tiled(self.tiled_scan_size)

    def take_zstacks_of_particles_tiled(self, shape, **kwargs):
        """Take z-stacked spectra of all the particles in several fields-of-view.

        We essentially run take_zstacks_of_particles for several fields of
        view, tiling them together in to the "shape" specified (2-element
        tuple). The centre of the tiled image is the current position.

        **kwargs are passed to ``pf_align_and_take_z_scan``.  All tiles share
        one newly created "particleScans/scan%d" group, and overview images
        are saved into it at the start of every tile.
        """
        self.spectrometer.live_view = False
        group = self.new_data_group("particleScans/scan%d", self.datafile)
        group.create_dataset("Raman_wavelengths", data=raman.GetWavelength())
        payload = self.pf_align_and_take_z_scan(datafile_group=group, **kwargs)
        self.tile_scans(shape,
                        tile_start_function=self.save_overview_images,
                        ts_args=[group],
                        payload_function=payload)

    def new_data_group(self, name="particleScans/scan%d", parent=None):
        """Create and return a new, uniquely numbered group in `parent`.

        name: template with one ``%d`` placeholder; the smallest
            non-negative integer not yet present in `parent` is used.
        parent: container to create the group in; ``self.datafile`` when
            ``None``.
        """
        if parent is None:
            parent = self.datafile
        index = 0
        while True:
            candidate = name % index
            if candidate not in parent:
                break
            index += 1
        return parent.create_group(name % index)

    def new_dataset_name(self, g, name):
        """Return the first unused numbered name in `g`.

        name: template with one ``%d`` placeholder; numbers are tried from
            zero upwards until the formatted name is not in `g`.
        """
        index = 0
        while True:
            candidate = name % index
            if candidate not in g:
                return candidate
            index += 1

    def save_overview_images(self, datafile_group):
        """Save an unmodified and a thresholded image, as a reference for scans.

        datafile_group: group in which the two datasets are created.

        The camera is autofocused and a fresh frame is grabbed first.  Both
        datasets are annotated with the stage position, the camera centre
        position, the camera-to-sample matrix and a timestamp; the
        thresholded image also records the filter settings used.
        """
        self.csm.autofocus_iterate(np.arange(-5, 5, 0.5))
        time.sleep(1)
        self.csm.camera.update_latest_frame()
        img1 = datafile_group.create_dataset(
            self.new_dataset_name(datafile_group, "overview_image_%d"),
            data=self.csm.camera.color_image())
        img1.attrs.create("stage_position", self.csm.stage.position())
        img1.attrs.create("camera_centre_position",
                          self.csm.camera_centre_position())
        img1.attrs.create("mapping_matrix_camera_to_sample",
                          self.csm.camera_to_sample)
        img1.attrs.create("timestamp", datetime.datetime.now().isoformat())
        img2 = datafile_group.create_dataset(
            self.new_dataset_name(datafile_group,
                                  "overview_image_%d_thresholded"),
            data=self.threshold_image(
                self.denoise_image(self.csm.camera.gray_image())))
        img2.attrs.create("stage_position", self.csm.stage.position())
        img2.attrs.create("camera_centre_position",
                          self.csm.camera_centre_position())
        # Record the filter settings that produced the thresholded image.
        filter_settings = self.get(['median_filter_width', 'threshold_level'])
        for key, val in filter_settings.items():
            img2.attrs.create(key, val)
        img2.attrs.create("camera_to_sample_matrix", self.csm.camera_to_sample)
        img2.attrs.create("timestamp", datetime.datetime.now().isoformat())

    def pf_align_and_take_z_scan(self,
                                 dz=np.arange(-4, 4, 0.4),
                                 datafile_group=None):
        """Set up for a scan of all particles, then return a payload function.
        
        The "payload function" is suitable for the eponymous argument of 
        go_to_particles, and will autofocus, align particle to fibre, and take
        a Z stack.  NB the payload function "wraps up" the arguments neatly so
        we don't need to store things like the depth of the Z stack.
        """
        if datafile_group is None:
            datafile_group = self.new_data_group("particleScans/scan%d",
                                                 self.datafile)

        def align_and_take_z_scan():
            # --- Initialize shutter positions by making sure they are all closed
            print("Matt: Initializing shutter positions.")
            light_shutter.close_shutter()
            raman.shutter.close_shutter()
            # --- Open white light shutter and close laser shutter
            print("Matt: Opening white light shutter.")
            light_shutter.open_shutter()
            # --- Fully open Shamrock slit
            print("Matt: Opening spectrometer slit.")
            raman.sham.SetSlit(2000)
            # --- Wait a bit for settings to be applied
            time.sleep(1)
            # --- Do Autofocus using OceanOptics Spectrum
            print("Matt: Doing autofocus.")
            self.csm.autofocus_iterate(np.arange(-2.5, 2.5, 0.5))
            self.aligner.spectrometer.integration_time = 300.  #short integration time for alignment
            self.aligner.optimise_2D(tolerance=0.03, stepsize=0.2)
            self.aligner.spectrometer.integration_time = 1000.  #long integration time for measurement
            # --- Initialize datafile
            print("Matt: Initializing datafile.")
            g = self.new_data_group("z_scan_%d", datafile_group)
            dset = g.create_dataset("z_scan", data=self.aligner.z_scan(dz))
            for key, val in self.aligner.spectrometer.get_metadata().iteritems(
            ):
                dset.attrs.create(key, val)
            dset.attrs.create("stage_position", self.csm.stage.position())
            dset.attrs.create("camera_centre_position",
                              self.csm.camera_centre_position())
            dset.attrs.create("timestamp", datetime.datetime.now().isoformat())
            dset.attrs.create("dz", dz)
            # --- Close all shutters
            print("Matt: Closing white light shutter.")
            light_shutter.close_shutter()
            # --- Set Infinity3 camera settings --- TODO !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
            # --- Set Andor camera settings --- TODO !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
            #we're going to take a picture - best make sure we've waited a moment for the focus to return
            time.sleep(0.3)
            # --- Take bias images on Infinity3, Andor, and OceanOptics
            # --- Infinity3
            print("Matt: Taking Infinity3 bias image.")
            self.csm.camera.update_latest_frame(
            )  #take a frame and ignore (for freshness)
            image = self.csm.camera.color_image()
            img = g.create_dataset("Infinity3_Bias_Image", data=image)
            img.attrs.create("stage_position", self.csm.stage.position())
            img.attrs.create("timestamp", datetime.datetime.now().isoformat())
            # --- Andor 0 order
            print("Matt: Taking Andor 0 order bias image.")
            raman.sham.GotoZeroOrder()
            time.sleep(5)
            image = np.reshape(raman.take_bkg(), (-1, raman.sham.pixel_number))
            wavelengths = raman.GetWavelength()
            rint = g.create_dataset("Raman_Bias_0Order_int", data=image)
            g.create_dataset("Raman_Bias_0Order_wl", data=wavelengths)
            rint.attrs.create("Laser power", raman.laser_power)
            rint.attrs.create("Slit size", raman.slit_size)
            rint.attrs.create("Integration time", raman.Integration_time)
            rint.attrs.create("description", raman.scan_desc)
            # --- Andor spectrum
            print("Matt: Taking Andor spectrum bias image.")
            raman.sham.SetWavelength(raman.centre_Wavelength)
            time.sleep(5)
            image = np.reshape(raman.take_bkg(), (-1, raman.sham.pixel_number))
            wavelengths = raman.GetWavelength()
            rint = g.create_dataset("Raman_Bias_Spectrum_int", data=image)
            g.create_dataset("Raman_Bias_Spectrum_wl", data=wavelengths)
            rint.attrs.create("Laser power", raman.laser_power)
            rint.attrs.create("Slit size", raman.slit_size)
            rint.attrs.create("Integration time", raman.Integration_time)
            rint.attrs.create("description", raman.scan_desc)
            # --- OceanOptics
            print("Matt: Taking OceanOptics bias image.")
            (oowl, oospec) = spectrometer.read()
            g.create_dataset("OOptics_Bias_Spectrum_int", data=oospec)
            g.create_dataset("OOptics_Bias_Spectrum_wl", data=oowl)
            # --- Turn on white light
            print("Matt: Turning white light back on.")
            light_shutter.open_shutter()
            # --- Take Infinity3 image
            print("Matt: Taking first Infinity3 white light image.")
            #we're going to take a picture - best make sure we've waited a moment for the focus to return
            time.sleep(0.3)
            self.csm.camera.update_latest_frame(
            )  #take a frame and ignore (for freshness)
            image = self.csm.camera.color_image()
            img = g.create_dataset(
                "Infinity3_FirstWhiteLight_Image",
                data=image[image.shape[0] / 2 - 50:image.shape[0] / 2 + 50,
                           image.shape[1] / 2 - 50:image.shape[1] / 2 + 50])
            img.attrs.create("stage_position", self.csm.stage.position())
            img.attrs.create("timestamp", datetime.datetime.now().isoformat())
            # --- Take white light spectrum (Ocean Optics) TODO !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
            # --- Take white light spectrum (Andor)
            print("Matt: Taking white light spectrum on Andor")
            raman.sham.SetWavelength(raman.centre_Wavelength)
            time.sleep(5)
            image = np.reshape(raman.take_bkg(), (-1, raman.sham.pixel_number))
            wavelengths = raman.GetWavelength()
            rint = g.create_dataset("Raman_White_Light_Spectrum_int",
                                    data=image)
            g.create_dataset("Raman_White_Light_Spectrum_wl", data=wavelengths)
            rint.attrs.create("Laser power", raman.laser_power)
            rint.attrs.create("Slit size", raman.slit_size)
            rint.attrs.create("Integration time", raman.Integration_time)
            rint.attrs.create("description", raman.scan_desc)
            # --- Take white light image (Andor)
            print("Matt: Taking Andor 0 order white light image.")
            raman.sham.GotoZeroOrder()
            time.sleep(5)
            image = np.reshape(raman.take_bkg(), (-1, raman.sham.pixel_number))
            wavelengths = raman.GetWavelength()
            rint = g.create_dataset("Raman_White_Light_0Order_int", data=image)
            g.create_dataset("Raman_White_Light_0Order_wl", data=wavelengths)
            rint.attrs.create("Laser power", raman.laser_power)
            rint.attrs.create("Slit size", raman.slit_size)
            rint.attrs.create("Integration time", raman.Integration_time)
            rint.attrs.create("description", raman.scan_desc)
            # --- Turn off white light
            print("Matt: Closing white light shutter.")
            light_shutter.close_shutter()
            # --- Turn on laser
            print("Matt: Opening the laser shutter.")
            raman.shutter.open_shutter()
            # --- Set Infinity3 exposure/gain very low and image beam profile. Then restore old values
            oldExposure = cam.parameters[cam.parameters[0].list_names().index(
                'EXPOSURE')]._get_value()
            oldGain = cam.parameters[cam.parameters[0].list_names().index(
                'GAIN')]._get_value()
            cam.parameters[
                cam.parameters[0].list_names().index('EXPOSURE')]._set_value(
                    0)  #sometimes need to set to float(-inf)
            cam.parameters[cam.parameters[0].list_names().index(
                'GAIN')]._set_value(0)
            print("Matt: Taking Infinity3 image of laser beam profile.")
            #we're going to take a picture - best make sure we've waited a moment for the focus to return
            time.sleep(0.3)
            self.csm.camera.update_latest_frame(
            )  #take a frame and ignore (for freshness)
            image = self.csm.camera.color_image()
            img = g.create_dataset("Infinity3_Laser_Beam_Image", data=image)
            img.attrs.create("stage_position", self.csm.stage.position())
            img.attrs.create("timestamp", datetime.datetime.now().isoformat())
            cam.parameters[cam.parameters[0].list_names().index(
                'EXPOSURE')]._set_value(oldExposure)
            cam.parameters[cam.parameters[0].list_names().index(
                'GAIN')]._set_value(oldGain)
            # --- Take image of laser zero-order (Andor)
            print("Matt: Taking Andor 0 order laser image.")
            raman.sham.GotoZeroOrder()
            time.sleep(5)
            image = np.reshape(raman.take_bkg(), (-1, raman.sham.pixel_number))
            wavelengths = raman.GetWavelength()
            rint = g.create_dataset("Raman_Laser_0Order_int", data=image)
            g.create_dataset("Raman_Laser_0Order_wl", data=wavelengths)
            rint.attrs.create("Laser power", raman.laser_power)
            rint.attrs.create("Slit size", raman.slit_size)
            rint.attrs.create("Integration time", raman.Integration_time)
            rint.attrs.create("description", raman.scan_desc)
            # --- Take image of laser spectrum (Andor)
            print("Matt: Taking Andor spectrum laser image.")
            raman.sham.SetWavelength(raman.centre_Wavelength)
            time.sleep(5)
            image = np.reshape(raman.take_bkg(), (-1, raman.sham.pixel_number))
            wavelengths = raman.GetWavelength()
            rint = g.create_dataset("Raman_Laser_Spectrum_int", data=image)
            g.create_dataset("Raman_Laser_Spectrum_wl", data=wavelengths)
            rint.attrs.create("Laser power", raman.laser_power)
            rint.attrs.create("Slit size", raman.slit_size)
            rint.attrs.create("Integration time", raman.Integration_time)
            rint.attrs.create("description", raman.scan_desc)
            # --- Turn off laser
            print("Matt: Closing the laser shutter.")
            raman.shutter.close_shutter()
            # --- Open the white light shutter
            print("Matt: Turning white light back on.")
            light_shutter.open_shutter()
            # --- Take second Infinity3 white light image (to track drift)
            print("Matt: Taking second Infinity3 white light image.")
            #we're going to take a picture - best make sure we've waited a moment for the focus to return
            time.sleep(0.3)
            self.csm.camera.update_latest_frame(
            )  #take a frame and ignore (for freshness)
            image = self.csm.camera.color_image()
            img = g.create_dataset(
                "Infinity3_SecondWhiteLight_Image",
                data=image[image.shape[0] / 2 - 50:image.shape[0] / 2 + 50,
                           image.shape[1] / 2 - 50:image.shape[1] / 2 + 50])
            img.attrs.create("stage_position", self.csm.stage.position())
            img.attrs.create("timestamp", datetime.datetime.now().isoformat())
            # --- Move stage slightly to take background. (MAKE THIS ACTUALLY LOOK FOR A SPOT WITH NO PARTICLES) TODO !!!!!!!!!!!!!!!!
            self.csm.stage.move_rel([
                1, 0, 0
            ])  # Move the stage by one micron to a (hopefully) empty area
            # --- Take Infinity3 white light image (as evidence stage moved to a good location for background)
            print(
                "Matt: Taking first Infinity3 white light image for background location."
            )
            #we're going to take a picture - best make sure we've waited a moment for the focus to return
            time.sleep(0.3)
            self.csm.camera.update_latest_frame(
            )  #take a frame and ignore (for freshness)
            image = self.csm.camera.color_image()
            img = g.create_dataset(
                "Infinity3_FirstBkgndWhiteLight_Image",
                data=image[image.shape[0] / 2 - 50:image.shape[0] / 2 + 50,
                           image.shape[1] / 2 - 50:image.shape[1] / 2 + 50])
            img.attrs.create("stage_position", self.csm.stage.position())
            img.attrs.create("timestamp", datetime.datetime.now().isoformat())
            # --- Take white light spectrum at background location (Ocean Optics) TODO !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
            # --- Take white light spectrum at background location (Andor)
            print(
                "Matt: Taking white light spectrum at background location on Andor"
            )
            raman.sham.SetWavelength(raman.centre_Wavelength)
            time.sleep(5)
            image = np.reshape(raman.take_bkg(), (-1, raman.sham.pixel_number))
            wavelengths = raman.GetWavelength()
            rint = g.create_dataset("Raman_White_Light_Bkgnd_Spectrum_int",
                                    data=image)
            g.create_dataset("Raman_White_Light_Bkgnd_Spectrum_wl",
                             data=wavelengths)
            rint.attrs.create("Laser power", raman.laser_power)
            rint.attrs.create("Slit size", raman.slit_size)
            rint.attrs.create("Integration time", raman.Integration_time)
            rint.attrs.create("description", raman.scan_desc)
            # --- Take white light image at background location (Andor)
            print(
                "Matt: Taking Andor 0 order white light image at background location."
            )
            raman.sham.GotoZeroOrder()
            time.sleep(5)
            image = np.reshape(raman.take_bkg(), (-1, raman.sham.pixel_number))
            wavelengths = raman.GetWavelength()
            rint = g.create_dataset("Raman_White_Light_Bkgnd_0Order_int",
                                    data=image)
            g.create_dataset("Raman_White_Light_0Order_Bkgnd_wl",
                             data=wavelengths)
            rint.attrs.create("Laser power", raman.laser_power)
            rint.attrs.create("Slit size", raman.slit_size)
            rint.attrs.create("Integration time", raman.Integration_time)
            rint.attrs.create("description", raman.scan_desc)
            # --- Turn off white light
            print("Matt: Closing white light shutter.")
            light_shutter.close_shutter()
            # --- Turn on laser
            print("Matt: Opening the laser shutter.")
            raman.shutter.open_shutter()
            # --- Set Infinity3 exposure/gain very low and image beam profile. Then restore old values
            oldExposure = cam.parameters[cam.parameters[0].list_names().index(
                'EXPOSURE')]._get_value()
            oldGain = cam.parameters[cam.parameters[0].list_names().index(
                'GAIN')]._get_value()
            cam.parameters[
                cam.parameters[0].list_names().index('EXPOSURE')]._set_value(
                    0)  #sometimes need to set to float(-inf)
            cam.parameters[cam.parameters[0].list_names().index(
                'GAIN')]._set_value(0)
            print(
                "Matt: Taking Infinity3 image of laser beam profile at background location."
            )
            #we're going to take a picture - best make sure we've waited a moment for the focus to return
            time.sleep(0.3)
            self.csm.camera.update_latest_frame(
            )  #take a frame and ignore (for freshness)
            image = self.csm.camera.color_image()
            img = g.create_dataset("Infinity3_Laser_Beam_Image_atBkgndLoc",
                                   data=image)
            img.attrs.create("stage_position", self.csm.stage.position())
            img.attrs.create("timestamp", datetime.datetime.now().isoformat())
            cam.parameters[cam.parameters[0].list_names().index(
                'EXPOSURE')]._set_value(oldExposure)
            cam.parameters[cam.parameters[0].list_names().index(
                'GAIN')]._set_value(oldGain)
            # --- Take image of laser zero-order at background location (Andor)
            print(
                "Matt: Taking Andor 0 order laser image at background location."
            )
            raman.sham.GotoZeroOrder()
            time.sleep(5)
            image = np.reshape(raman.take_bkg(), (-1, raman.sham.pixel_number))
            wavelengths = raman.GetWavelength()
            rint = g.create_dataset("Raman_Laser_0Order_atBkgndLoc_int",
                                    data=image)
            g.create_dataset("Raman_Laser_0Order_atBkgndLoc_wl",
                             data=wavelengths)
            rint.attrs.create("Laser power", raman.laser_power)
            rint.attrs.create("Slit size", raman.slit_size)
            rint.attrs.create("Integration time", raman.Integration_time)
            rint.attrs.create("description", raman.scan_desc)
            # --- Take image of laser spectrum at background location (Andor)
            print(
                "Matt: Taking Andor spectrum laser image at background location."
            )
            raman.sham.SetWavelength(raman.centre_Wavelength)
            time.sleep(5)
            image = np.reshape(raman.take_bkg(), (-1, raman.sham.pixel_number))
            wavelengths = raman.GetWavelength()
            rint = g.create_dataset("Raman_Laser_Spectrum_atBkgndLoc_int",
                                    data=image)
            g.create_dataset("Raman_Laser_Spectrum_atBkgndLoc_wl",
                             data=wavelengths)
            rint.attrs.create("Laser power", raman.laser_power)
            rint.attrs.create("Slit size", raman.slit_size)
            rint.attrs.create("Integration time", raman.Integration_time)
            rint.attrs.create("description", raman.scan_desc)
            # --- Turn off laser
            print("Matt: Closing the laser shutter.")
            raman.shutter.close_shutter()
            # --- Open the white light shutter
            print("Matt: Turning white light back on.")
            light_shutter.open_shutter()
            # --- Take second Infinity3 white light image at background location (to track drift)
            print(
                "Matt: Taking second Infinity3 white light image at background location."
            )
            #we're going to take a picture - best make sure we've waited a moment for the focus to return
            time.sleep(0.3)
            self.csm.camera.update_latest_frame(
            )  #take a frame and ignore (for freshness)
            image = self.csm.camera.color_image()
            img = g.create_dataset(
                "Infinity3_SecondWhiteLight_atBkgndLoc_Image",
                data=image[image.shape[0] / 2 - 50:image.shape[0] / 2 + 50,
                           image.shape[1] / 2 - 50:image.shape[1] / 2 + 50])
            img.attrs.create("stage_position", self.csm.stage.position())
            img.attrs.create("timestamp", datetime.datetime.now().isoformat())
            # --- Turn off all light sources
            print("Matt: Closing white light shutter.")
            light_shutter.close_shutter()
            raman.shutter.close_shutter()

            ###################################################################################################################
            print("Reached the end of Matt's code.")
            datafile_group.file.flush()
            sys.exit("Stopping here because Matt is doing some tests.")
            ###################################################################################################################

            ### take andor image 0 order and a spectrum
            "first bring the laser into focus"
            here = stage.position()
            laser_focus = raman.AlignHeNe(
                dset)  ### remove, keep, or create offset
            print "Moving to HeNe Focus (%g)" % (laser_focus)
            stage.move_rel([0, 0, laser_focus])
            time.sleep(1)

            #            raman.shutter.trigger() ##Command for opneing and closing shutter
            #  raman.cam.

            light_shutter.close_shutter()

            # --------------- CODE HERE

            "record fast kinetic raman scan"
            time.sleep(0.3)
            raman.take_fast_kinetic()  ###take spectrum and 0 order image

            "save the data to the HDF5 output file"
            Raman_spc = g.create_dataset("kinetic_raman",
                                         data=raman.kinetic_scan_data)
            Raman_spc.attrs.create("laser power", raman.laserpower)
            Raman_spc.attrs.create("integration time", raman.exptime)
            Raman_spc.attrs.create("focus height", laser_focus)
            Raman_spc.attrs.create("times", raman.times)
            Raman_spc.attrs.create("wavelengths", raman.Raman_wavelengths)

            light_shutter.open_shutter()

            # ----------------- /CODE HERE

            #===================================================================
            self.csm.autofocus_iterate(np.arange(-2.5, 2.5, 0.5))
            self.aligner.spectrometer.integration_time = 300.  #short integration time for alignment
            #self.aligner.optimise_2D(tolerance=0.07,stepsize=0.3)
            self.aligner.optimise_2D(tolerance=0.03, stepsize=0.2)
            #self.aligner.optimise_2D(tolerance=0.03,stepsize=0.1)
            self.aligner.spectrometer.integration_time = 1000.  #long integration time for measurement
            dset2 = g.create_dataset("z_scan2", data=self.aligner.z_scan(dz))
            for key, val in self.aligner.spectrometer.get_metadata().iteritems(
            ):
                dset2.attrs.create(key, val)
            dset2.attrs.create("stage_position", self.csm.stage.position())
            dset2.attrs.create("camera_centre_position",
                               self.csm.camera_centre_position())
            dset2.attrs.create("timestamp",
                               datetime.datetime.now().isoformat())
            dset2.attrs.create("dz", dz)
            #we're going to take a picture - best make sure we've waited a moment for the focus to return
            time.sleep(0.3)
            self.csm.camera.update_latest_frame(
            )  #take a frame and ignore (for freshness)
            image = self.csm.camera.color_image()
            img = g.create_dataset(
                "camera_image2",
                data=image[image.shape[0] / 2 - 50:image.shape[0] / 2 + 50,
                           image.shape[1] / 2 - 50:image.shape[1] / 2 + 50])
            img.attrs.create("stage_position", self.csm.stage.position())
            img.attrs.create("timestamp", datetime.datetime.now().isoformat())
            #===================================================================

            time.sleep(0.3)
            stage.move(here)

            datafile_group.file.flush()

        return align_and_take_z_scan

    def plot_latest_scan(self):
        """Plot one background-subtracted spectrum per z-scan of the latest scan.

        Every member of ``self.latest_scan_group`` whose name begins with
        ``z_scan_<digits>`` contributes one line: its ``z_scan`` dataset is
        summed along axis 0, divided by the length of that axis, and the
        dataset's ``background`` attribute is subtracted.  The result is
        plotted against the dataset's ``wavelengths`` attribute and the figure
        is shown without blocking.
        """
        is_scan_name = re.compile(r"z_scan_\d+").match
        for member_name, member in self.latest_scan_group.iteritems():
            # Skip anything in the group that is not a per-particle z-scan.
            if not is_scan_name(member_name):
                continue
            z_scan = member['z_scan']
            summed = np.sum(z_scan, 0)
            wavelengths = z_scan.attrs['wavelengths']
            averaged = summed / z_scan.shape[0] - z_scan.attrs['background']
            plt.plot(wavelengths, averaged)
        plt.show(block=False)
Esempio n. 16
0
class TriangleWave(HasTraits):
    """Interactive demo of a triangle wave, its FFT spectrum and its resynthesis.

    Changing the shape traits or ``fftsize`` recomputes the waveform and its
    spectrum (``update_plot``); changing ``N`` only recomputes the curve
    synthesized from the stored FFT coefficients (``plot_sin_combine``).  All
    results are written into ``plot_data``, which backs the two plots held in
    ``container``.
    """

    # Narrowest and widest allowed wave widths.  Range does not seem to accept
    # a mix of constants and trait names, so both bounds are declared as
    # (unchanging) trait attributes.
    low = Float(0.02)
    hi = Float(1.0)

    # Width of the triangle waveform.
    wave_width = Range("low", "hi", 0.5)

    # x coordinate of the triangle's apex C.
    length_c = Range("low", "wave_width", 0.5)

    # y coordinate of the triangle's apex.
    height_c = Float(1.0)

    # Number of sample points used for the FFT; an Enum trait lets the user
    # pick one of the listed powers of two (2**6 .. 2**11).
    fftsize = Enum( [(2**x) for x in range(6, 12)])

    # Upper limit of the x axis of the FFT spectrum plot.
    fft_graph_up_limit = Range(0, 400, 20)

    # Text used to display the FFT result.
    peak_list = Str

    # Number of frequencies used to synthesize the triangle wave.
    N = Range(1, 40, 4)

    # Object holding the data shared by the plots.
    plot_data = Instance(AbstractPlotData)

    # Component drawing the waveform plot.
    plot_wave = Instance(Component)

    # Component drawing the FFT spectrum plot.
    plot_fft  = Instance(Component)

    # Container holding both plots.
    container = Instance(Component)

    # User-interface view.  The window size must be given explicitly so that
    # the plot container can initialize properly.
    view = View(
        HSplit(
            VSplit(
                VGroup(
                    Item("wave_width", editor = scrubber, label=u"波形宽度"),
                    Item("length_c", editor = scrubber, label=u"最高点x坐标"),
                    Item("height_c", editor = scrubber, label=u"最高点y坐标"),
                    Item("fft_graph_up_limit", editor = scrubber, label=u"频谱图范围"),
                    Item("fftsize", label=u"FFT点数"),
                    Item("N", label=u"合成波频率数")
                ),
                Item("peak_list", style="custom", show_label=False, width=100, height=250)
            ),
            VGroup(
                Item("container", editor=ComponentEditor(size=(600,300)), show_label = False),
                orientation = "vertical"
            )
        ),
        resizable = True,
        width = 800,
        height = 600,
        title = u"三角波FFT演示"
    )

    def _create_plot(self, data, name, type="line"):
        """Build a titled Plot of `data` with pan and box-zoom tools attached.

        The waveform and spectrum plots are set up almost identically, so the
        shared steps live here.

        `data` is the tuple of ``plot_data`` names to draw, `name` is used as
        both the renderer name and the plot title, and `type` is passed
        through to ``Plot.plot``.
        """
        p = Plot(self.plot_data)
        p.plot(data, name=name, title=name, type=type)
        p.tools.append(PanTool(p))
        zoom = ZoomTool(component=p, tool_mode="box", always_on=False)
        p.overlays.append(zoom)
        p.title = name
        return p

    def __init__(self):
        """Create the data set, both plots and their container."""
        # The parent initializer must run first.
        super(TriangleWave, self).__init__()

        # Create the plot data set.  There is no data yet, so every entry is
        # empty; this only registers the names the plots refer to.
        self.plot_data = ArrayPlotData(x=[], y=[], f=[], p=[], x2=[], y2=[])

        # Vertical container that stacks the spectrum and waveform plots.
        self.container = VPlotContainer()

        # Waveform plot with two curves: the original wave (x, y) and the
        # synthesized wave (x2, y2).
        self.plot_wave = self._create_plot(("x","y"), "Triangle Wave")
        self.plot_wave.plot(("x2","y2"), color="red")

        # Spectrum plot, drawn from the f and p entries of the data set.
        self.plot_fft  = self._create_plot(("f","p"), "FFT", type="scatter")

        # Add both plots to the vertical container.
        self.container.add( self.plot_wave )
        self.container.add( self.plot_fft )

        # Axis titles.
        self.plot_wave.x_axis.title = "Samples"
        self.plot_fft.x_axis.title = "Frequency pins"
        self.plot_fft.y_axis.title = "(dB)"

        # Switch fftsize to 1024, because an Enum defaults to the first value
        # of its list.  This assignment also fires update_plot.
        self.fftsize = 1024

    def _fft_graph_up_limit_changed(self):
        """Apply the new upper x-axis limit to the spectrum plot."""
        self.plot_fft.x_axis.mapper.range.high = self.fft_graph_up_limit

    def _N_changed(self):
        """Recompute the synthesized wave when the frequency count changes."""
        self.plot_sin_combine()

    # When several traits share one change handler, @on_trait_change can name
    # them all at once.
    @on_trait_change("wave_width, length_c, height_c, fftsize")
    def update_plot(self):
        """Recompute the triangle wave and its spectrum and refresh the plots.

        Also fills ``peak_list`` with up to 20 "index : dB" lines for the
        frequency bins above -80 dB, stores the FFT coefficients in
        ``self.fft_parameters`` and refreshes the synthesized wave.
        """
        # The sampled waveform is also published as the module-level global
        # `y_data`; kept so code outside this class that reads it still works.
        global y_data
        x_data = np.arange(0, 1.0, 1.0/self.fftsize)
        func = self.triangle_func()
        # The ufunc returns an object array; convert it to float64.
        # (np.cast was removed in NumPy 2.0, so astype is used instead.)
        y_data = np.asarray(func(x_data)).astype(np.float64)

        # Spectrum, normalized by the number of samples.
        fft_parameters = np.fft.fft(y_data) / len(y_data)

        # Amplitude of each frequency in dB, clipped to [-120, 120].  Floor
        # division keeps the slice bound an integer under true division too.
        fft_data = np.clip(20*np.log10(np.abs(fft_parameters))[:self.fftsize//2+1], -120, 120)

        # Write the results into the data set.
        self.plot_data.set_data("x", np.arange(0, self.fftsize)) # x is the sample index
        self.plot_data.set_data("y", y_data)
        self.plot_data.set_data("f", np.arange(0, len(fft_data))) # x is the frequency index
        self.plot_data.set_data("p", fft_data)

        # The synthesized wave is drawn over two periods of sample indices.
        self.plot_data.set_data("x2", np.arange(0, 2*self.fftsize))

        # Refresh the upper x-axis limit of the spectrum plot.
        self._fft_graph_up_limit_changed()

        # List the frequencies whose amplitude exceeds -80 dB (first 20 only).
        peak_index = (fft_data > -80)
        peak_value = fft_data[peak_index][:20]
        result = []
        for f, v in zip(np.flatnonzero(peak_index), peak_value):
            result.append("%s : %s" %(f, v) )
        self.peak_list = "\n".join(result)

        # Keep the current FFT result and recompute the sine synthesis.
        self.fft_parameters = fft_parameters
        self.plot_sin_combine()

    def plot_sin_combine(self):
        """Synthesize two periods from the first ``N`` FFT terms and plot them."""
        index, data = fft_combine(self.fft_parameters, self.N, 2)
        self.plot_data.set_data("y2", data)

    def triangle_func(self):
        """Return a ufunc evaluating the triangle wave for the current traits.

        The ufunc is built with ``np.frompyfunc`` and therefore yields an
        object array that the caller has to convert to a numeric dtype.
        """
        c = self.wave_width
        c0 = self.length_c
        hc = self.height_c

        def trifunc(x):
            # The wave has period 1, so only the fractional part of x matters.
            x = x - int(x)
            if x >= c: r = 0.0
            elif x < c0: r = x / c0 * hc
            else: r = (c-x) / (c-c0) * hc
            return r

        # Wrap trifunc as a ufunc so it can be applied to whole arrays.
        return np.frompyfunc(trifunc, 1, 1)
Esempio n. 17
0
class Glyph(Component):
    """Pipeline component that wraps a tvtk glyph filter and its glyph source.

    Depending on ``glyph_type`` the wrapped filter is a ``tvtk.Glyph3D``
    (vector) or a ``tvtk.TensorGlyph`` (tensor).  The component optionally
    routes its input through a ``tvtk.MaskPoints`` filter and keeps the
    glyph's scale/color modes in sync with its own traits.
    """

    # The version of this class.  Used for persistence.
    __version__ = 0

    # Type of Glyph: 'tensor' or 'vector'
    glyph_type = Enum('vector',
                      'tensor',
                      desc='if the glyph is vector or tensor')

    # The scaling mode to use when scaling the glyphs.  We could have
    # used the glyph's own scale mode but it allows users to set the
    # mode to use vector components for the scaling which I'd like to
    # disallow.
    scale_mode = Trait(
        'scale_by_scalar',
        TraitRevPrefixMap({
            'scale_by_vector': 1,
            'scale_by_vector_components': 2,
            'data_scaling_off': 3,
            'scale_by_scalar': 0
        }),
        desc="if scaling is done using scalar or vector/normal magnitude")

    # The color mode to use when coloring the glyphs.  We could have
    # used the glyph's own color_mode trait but it allows users to set
    # the mode to use vector components for the scaling which I'd
    # like to disallow.
    color_mode = Trait(
        'color_by_scalar',
        TraitRevPrefixMap({
            'color_by_vector': 2,
            'color_by_scalar': 1,
            'no_coloring': 0
        }),
        desc="if coloring is done by scalar or vector/normal magnitude")

    # The color mode used instead of `color_mode` when the glyph type is
    # 'tensor'.
    color_mode_tensor = Trait(
        'scalar',
        TraitRevPrefixMap({
            'scalars': 1,
            'eigenvalues': 2,
            'no_coloring': 0
        }),
        desc="if coloring is done by scalar or eigenvalues")

    # Specify if the input points must be masked.  By mask we mean
    # that only a subset of the input points must be displayed.
    mask_input_points = Bool(False, desc="if input points are masked")

    # The MaskPoints filter.
    mask_points = Instance(tvtk.MaskPoints,
                           args=(),
                           kw={'random_mode': True},
                           record=True)

    # The Glyph3D instance.
    glyph = Instance(tvtk.Object, allow_none=False, record=True)

    # The Source to use for the glyph.  This is chosen from
    # `self._glyph_list` or `self.glyph_dict`.
    glyph_source = Instance(glyph_source.GlyphSource,
                            allow_none=False,
                            record=True)

    # The module associated with this component.  This is used to get
    # the data range of the glyph when the scale mode changes.  This
    # *must* be set if this module is to work correctly.
    module = Instance(Module)

    # Should we show the GUI option for changing the scalar mode or
    # not?  This is useful for vector glyphing modules where there it
    # does not make sense to scale the data based on scalars.
    show_scale_mode = Bool(True)

    ########################################
    # Private traits.

    # Used for optimization: while True, `render` is a no-op so that a
    # handler changing several glyph traits triggers a single render.
    _updating = Bool(False)

    ########################################
    # View related traits.

    view = View(
        Group(
            Item(name='mask_input_points'),
            Group(
                Item(name='mask_points',
                     enabled_when='object.mask_input_points',
                     style='custom',
                     resizable=True),
                show_labels=False,
            ),
            label='Masking',
        ),
        Group(
            Group(
                Item(name='scale_mode',
                     enabled_when='show_scale_mode',
                     visible_when='show_scale_mode'),
                Item(name='color_mode',
                     enabled_when='glyph_type == "vector"',
                     visible_when='glyph_type == "vector"'),
                Item(name='color_mode_tensor',
                     enabled_when='glyph_type == "tensor"',
                     visible_when='glyph_type == "tensor"'),
            ),
            Group(Item(name='glyph', style='custom', resizable=True),
                  show_labels=False),
            label='Glyph',
            selected=True,
        ),
        Group(
            Item(name='glyph_source', style='custom', resizable=True),
            show_labels=False,
            label='Glyph Source',
        ),
        resizable=True)

    ######################################################################
    # `object` interface
    ######################################################################
    def __get_pure_state__(self):
        """Return the persisted state without `module` and `_updating`."""
        d = super(Glyph, self).__get_pure_state__()
        for attr in ('module', '_updating'):
            d.pop(attr, None)
        return d

    ######################################################################
    # `Module` interface
    ######################################################################
    def setup_pipeline(self):
        """Create the glyph filter and glyph source and wire up handlers.

        Builds the glyph filter matching the current `glyph_type`, creates a
        fresh `GlyphSource`, re-configures the glyph's source whenever that
        source fires `pipeline_changed`, and re-renders on any change of the
        mask filter.
        """
        self._glyph_type_changed(self.glyph_type)
        self.glyph_source = glyph_source.GlyphSource()

        # Handlers to setup our source when the sources pipeline changes.
        self.glyph_source.on_trait_change(self._update_source,
                                          'pipeline_changed')
        self.mask_points.on_trait_change(self.render)

    def update_pipeline(self):
        """Re-apply masking, color mode and scale mode, then publish outputs.

        Does nothing when there are no inputs or the first input has no
        outputs.  Otherwise the glyph's outputs are configured as this
        component's outputs and `pipeline_changed` is set.
        """
        if ((len(self.inputs) == 0) or (len(self.inputs[0].outputs) == 0)):
            return

        self._mask_input_points_changed(self.mask_input_points)
        if self.glyph_type == 'vector':
            self._color_mode_changed(self.color_mode)
        else:
            self._color_mode_tensor_changed(self.color_mode_tensor)
        self._scale_mode_changed(self.scale_mode)

        # Set our output.
        tvtk_common.configure_outputs(self, self.glyph)
        self.pipeline_changed = True

    def update_data(self):
        """Re-apply the scale mode (refreshing the glyph range) and set
        `data_changed`.
        """
        self._scale_mode_changed(self.scale_mode)
        self.data_changed = True

    def render(self):
        """Render unless a handler is currently batching updates."""
        if not self._updating:
            super(Glyph, self).render()

    def start(self):
        """Start the glyph source and then this component, if not running."""
        if self.running:
            return
        self.glyph_source.start()
        super(Glyph, self).start()

    def stop(self):
        """Stop the glyph source and then this component, if running."""
        if not self.running:
            return
        self.glyph_source.stop()
        super(Glyph, self).stop()

    def has_output_port(self):
        """ The filter has an output port."""
        return True

    def get_output_object(self):
        """ Returns the output port."""
        return self.glyph.output_port

    ######################################################################
    # Non-public methods.
    ######################################################################
    def _update_source(self):
        """Feed the glyph source's first output to the glyph filter."""
        self.configure_source_data(self.glyph, self.glyph_source.outputs[0])

    def _glyph_source_changed(self, value):
        """Feed the new glyph source's first output to the glyph filter."""
        self.configure_source_data(self.glyph, value.outputs[0])

    def _color_mode_changed(self, value):
        """Forward the vector color mode to the glyph filter.

        'no_coloring' is not forwarded; the glyph's color mode is then left
        as it was.
        """
        if len(self.inputs) == 0:
            return
        if value != 'no_coloring':
            self.glyph.color_mode = value

    def _color_mode_tensor_changed(self, value):
        """Forward the tensor color mode to the glyph filter and render once.

        'no_coloring' switches `color_glyphs` off; any other value is set as
        the glyph's color mode and switches `color_glyphs` on.
        """
        if len(self.inputs) == 0:
            return
        self._updating = True
        try:
            if value != 'no_coloring':
                self.glyph.color_mode = value
                self.glyph.color_glyphs = True
            else:
                self.glyph.color_glyphs = False
        finally:
            # Always leave batching mode, as `_scale_mode_changed` does;
            # otherwise an exception above would suppress every later render.
            self._updating = False
            self.render()

    def _scale_mode_changed(self, value):
        """Forward the scale mode to the glyph and refresh its data range.

        Skipped when no module is set, there are no inputs, or the glyph
        type is 'tensor'.  The range comes from the module manager's scalar
        LUT manager for 'scale_by_scalar' and from its vector LUT manager
        for every other mode.
        """
        if (self.module is None) or (len(self.inputs) == 0)\
                                 or self.glyph_type == 'tensor':
            return

        self._updating = True
        try:
            glyph = self.glyph
            glyph.scale_mode = value

            mm = self.module.module_manager
            if glyph.scale_mode == 'scale_by_scalar':
                glyph.range = tuple(mm.scalar_lut_manager.data_range)
            else:
                glyph.range = tuple(mm.vector_lut_manager.data_range)
        finally:
            self._updating = False
            self.render()

    def _mask_input_points_changed(self, value):
        """Route the first input through the mask filter or straight to the
        glyph, depending on `value`.
        """
        inputs = self.inputs
        if len(inputs) == 0:
            return
        if value:
            mask = self.mask_points
            tvtk_common.configure_input(mask, inputs[0].outputs[0])
            # NOTE(review): the glyph is not connected to the mask's output
            # in this branch -- confirm that this is handled elsewhere.
        else:
            self.configure_connection(self.glyph, inputs[0])

    def _glyph_type_changed(self, value):
        """Create the tvtk filter matching the glyph type.

        Tensor glyphs also hide the scale-mode option.  Any trait change on
        the new filter triggers a render.
        """
        if self.glyph_type == 'vector':
            self.glyph = tvtk.Glyph3D(clamping=True)
        else:
            self.glyph = tvtk.TensorGlyph(scale_factor=0.1)
            self.show_scale_mode = False
        self.glyph.on_trait_change(self.render)

    def _scene_changed(self, old, new):
        """Propagate a scene change to the glyph source."""
        super(Glyph, self)._scene_changed(old, new)
        self.glyph_source.scene = new
Esempio n. 18
0
class SvgRangeSelectionOverlay(StatusLayer):
    """ This is a primitive range selection overlay which uses
        a SVG to define the overlay.

        TODO: not inherit from StatusLayer, this was a convenience for a
            quick prototype

        TODO: use 2 svgs, one which defines the border and does not scale, and
            the other which defines the fill.
    """

    # Path of the SVG document, located in the 'data' directory next to this
    # module.
    filename = os.path.join(os.path.dirname(__file__), 'data',
                            'range_selection.svg')

    # Alpha value set on the graphics context before the SVG is rendered.
    alpha = 0.5

    # The axis to which this tool is perpendicular.
    axis = Enum("index", "value")

    # 0 when `axis` is 'index', 1 otherwise.
    axis_index = Property(depends_on='axis')

    # Mapping from screen space to data space. By default, it is just
    # self.component.
    plot = Property(depends_on='component')

    # The mapper (and associated range) that drive this RangeSelectionOverlay.
    # By default, this is the mapper on self.plot that corresponds to self.axis.
    mapper = Property(depends_on='plot')

    # The name of the metadata to look at for dataspace bounds. The metadata
    # can be either a tuple (dataspace_start, dataspace_end) in "selections" or
    # a boolean array mask of selected dataspace points with any other name
    metadata_name = Str("selections")

    def overlay(self, component, gc, view_bounds=None, mode="normal"):
        """ Draws this component overlaid on another component.

        Overrides AbstractOverlay.

        Nothing is drawn when there is no selection.  Otherwise the SVG
        document is translated to the start of the first selected range and
        scaled so that it covers that range along the selection axis and the
        full plot extent along the other axis.
        """
        # Draw the selection
        coords = self._get_selection_screencoords()

        if len(coords) == 0:
            return

        with gc:
            gc.set_alpha(self.alpha)

            plot_width = self.component.width
            plot_height = self.component.height

            origin_x = self.component.padding_left
            origin_y = self.component.padding_top

            if self.axis == 'index':
                if isinstance(self.mapper, GridMapper):
                    scale_width = (coords[-1][0] -
                                   coords[0][0]) / self.doc_width
                else:
                    scale_width = (coords[0][-1] -
                                   coords[0][0]) / self.doc_width
                scale_height = float(plot_height) / self.doc_height
                gc.translate_ctm(coords[0][0], origin_y + plot_height)
            else:
                scale_height = (coords[0][-1] - coords[0][0]) / self.doc_height
                scale_width = float(plot_width) / self.doc_width
                gc.translate_ctm(origin_x, coords[0][0])

            # SVG origin is the upper left with y positive down, so
            # we need to flip the y axis
            gc.scale_ctm(scale_width, -scale_height)

            self.document.render(gc)

            self._draw_component(gc, view_bounds, mode)

        return

    def _get_selection_screencoords(self):
        """ Returns the screen space extents of the current selection(s).

        The result is a list with one entry per selected range; each entry
        is whatever ``self.mapper.map_screen`` returns for the
        ``(start, end)`` data values of that range.

        If there is no current selection (the metadata entry is missing,
        None, or a "selections" value that does not have two items), an
        empty list is returned.
        """
        ds = getattr(self.plot, self.axis)
        # A data source that never had a selection may lack the key entirely.
        selection = ds.metadata.get(self.metadata_name)
        if selection is None:
            return []

        # "selections" metadata must be a tuple
        if self.metadata_name == "selections":
            if len(selection) == 2:
                return [self.mapper.map_screen(numpy.array(selection))]
            return []

        # All other metadata is interpreted as a mask on dataspace
        ar = numpy.arange(0, len(selection), 1)
        selected = ar[selection]
        coords = []
        for inds in arg_find_runs(selected):
            start = ds._data[selected[inds[0]]]
            end = ds._data[selected[inds[1] - 1]]
            # Use the same mapper as the "selections" branch above.
            coords.append(self.mapper.map_screen(numpy.array((start, end))))
        return coords

    @cached_property
    def _get_plot(self):
        return self.component

    @cached_property
    def _get_axis_index(self):
        if self.axis == 'index':
            return 0
        else:
            return 1

    @cached_property
    def _get_mapper(self):
        # If the plot's mapper is a GridMapper, return either its
        # x mapper or y mapper

        mapper = getattr(self.plot, self.axis + "_mapper")

        if isinstance(mapper, GridMapper):
            if self.axis == 'index':
                return mapper._xmapper
            else:
                return mapper._ymapper
        else:
            return mapper
Esempio n. 19
0
from __future__ import absolute_import

from operator import itemgetter

from traits.api import BaseTraitHandler, CTrait, Enum, TraitError

from .ui_traits import SequenceTypes
import six

# -------------------------------------------------------------------------
#  Trait definitions:
# -------------------------------------------------------------------------

# Layout orientation for a control and its associated editor; one of
# "horizontal" or "vertical".
Orientation = Enum("horizontal", "vertical")

# Docking drag bar style; one of "horizontal", "vertical", "tab" or "fixed".
DockStyle = Enum("horizontal", "vertical", "tab", "fixed")


def user_name_for(name):
    """ Returns a "user-friendly" name for a specified trait.
    """
    # NOTE(review): only the beginning of this function is visible in this
    # snippet; the rest of the loop body and the return are not shown.

    # Underscores become spaces and the first character is upper-cased.
    name = name.replace("_", " ")
    name = name[:1].upper() + name[1:]
    result = ""
    last_lower = 0
    for c in name:
        # A space is inserted ahead of an upper-case character whenever
        # 'last_lower' is set.
        if c.isupper() and last_lower:
            result += " "
Esempio n. 20
0
class Button(Component):
    """ A rectangular push button with an optional text label.

    The button is drawn as a filled, outlined rectangle whose fill and label
    colors depend on ``button_state``.  Subclasses override ``perform`` to
    react to a click.
    """

    # Fill color while the button is up
    color = ColorTrait("lightblue")

    # Fill color while the button is down
    down_color = ColorTrait("darkblue")

    # Outline color, used in both states
    border_color = ColorTrait("blue")

    # Text drawn on the button; nothing is drawn when it is empty
    label = Str

    # Font used for the label
    label_font = KivaFont("modern 12")

    # Label color while the button is up
    label_color = ColorTrait("white")

    # Label color while the button is down
    down_label_color = ColorTrait("white")

    # Current visual state of the button
    button_state = Enum("up", "down")

    # A reference to the radio group that this button belongs to
    radio_group = Any

    # Default size of the button if no label is present
    bounds = [32, 32]

    # Generally, buttons are not resizable
    resizable = ""

    # Set on left-button press, cleared on left-button release
    _got_mousedown = Bool(False)

    def perform(self, event):
        """
        Called when the button is depressed.  'event' is the Enable mouse event
        that triggered this call.
        """

    def _draw_mainlayer(self, gc, view_bounds, mode="default"):
        """ Draws the button using the routine matching its current state. """
        if self.button_state == "up":
            painter = self.draw_up
        else:
            painter = self.draw_down
        painter(gc, view_bounds)

    def draw_up(self, gc, view_bounds):
        """ Draws the button in its "up" appearance. """
        with gc:
            gc.set_fill_color(self.color_)
            gc.set_stroke_color(self.border_color_)
            # Width and height are reduced by one pixel for the outline.
            rect = (int(self.x), int(self.y),
                    int(self.width) - 1, int(self.height) - 1)
            gc.draw_rect(rect, FILL_STROKE)
            self._draw_label(gc)

    def draw_down(self, gc, view_bounds):
        """ Draws the button in its "down" appearance. """
        with gc:
            gc.set_fill_color(self.down_color_)
            gc.set_stroke_color(self.border_color_)
            # Width and height are reduced by one pixel for the outline.
            rect = (int(self.x), int(self.y),
                    int(self.width) - 1, int(self.height) - 1)
            gc.draw_rect(rect, FILL_STROKE)
            self._draw_label(gc, color=self.down_label_color_)

    def _draw_label(self, gc, color=None):
        """ Draws the label centered in the button; does nothing when the
        label is empty.  'color' defaults to the "up" label color.
        """
        if self.label == "":
            return

        gc.set_font(self.label_font)
        x, y, w, h = gc.get_text_extent(self.label)
        text_color = self.label_color_ if color is None else color
        gc.set_fill_color(text_color)
        gc.set_stroke_color(text_color)
        text_x = self.x + (self.width - w - x) / 2
        text_y = self.y + (self.height - h - y) / 2
        gc.show_text(self.label, (text_x, text_y))

    def normal_left_down(self, event):
        """ Left-button press: show the button as down and consume the event.
        """
        self.button_state = "down"
        self._got_mousedown = True
        self.request_redraw()
        event.handled = True

    def normal_left_up(self, event):
        """ Left-button release: show the button as up, run ``perform`` and
        consume the event.
        """
        self.button_state = "up"
        self._got_mousedown = False
        self.request_redraw()
        self.perform(event)
        event.handled = True
Esempio n. 21
0
class PlotCorr(HasTraits):
    """ Traits UI model with a 'Graph' tab showing ``plot`` and a 'Config'
    tab holding the plot options; the current plot can be saved to an
    image file.
    """

    # Data to plot
    data = Array()

    # Plot configuration, edited on the 'Config' tab
    plot_type = Enum('CMap', 'Overlay', 'Vertical', 'Multiline')
    p_title = Str('Graph')
    x_lbl = Str('Time')
    y_lbl_type = Enum('Corr', 'Single', 'Custom')
    y_lbl = Str('Corr')
    y_labels = List(Str)
    scale_type = Enum('Time', 'default')
    first_day = Date()
    apply_btn = Button('Apply')
    save_btn = Button('Save')

    y_low = Float(100.0)
    y_high = Float(-1.0)

    multi_line_plot_renderer = Instance(MultiLinePlot)
    # Drives multi_line_plot_renderer.normalized_amplitude
    amplitude = Range(-20.0, 20.0, value=1)
    # Drives multi_line_plot_renderer.offset
    offset = Range(-15.0, 15.0, value=0)

    # The component shown on the 'Graph' tab
    plot = Instance(Component)

    # One property per supported plot type
    plot_vertical = Property(Instance(Component))
    plot_overlay = Property(Instance(Component))
    plot_cmap = Property(Instance(Component))
    plot_multiline = Property(Instance(Component))

    view = View(Tabbed(
        VGroup(Item('plot',
                    editor=ComponentEditor(),
                    width=600,
                    height=400,
                    show_label=False),
               Item('amplitude',
                    label='amp',
                    visible_when="plot_type=='Multiline'"),
               Item('offset',
                    label='offset',
                    visible_when="plot_type=='Multiline'"),
               Item('save_btn', show_label=False, width=200, height=100),
               label='Graph'),
        VGroup(Item('plot_type', style='custom'),
               Item('p_title', label='graph title'),
               Item('x_lbl', label='x axis label'),
               Item('y_lbl_type', style='custom'),
               Item('y_lbl',
                    label='y axis label',
                    visible_when="y_lbl_type=='Single'"),
               Item('y_labels',
                    label='y axis labels',
                    visible_when="y_lbl_type=='Custom'"),
               Item('scale_type', style='custom'),
               Item('first_day', enabled_when="scale_type=='Time'"),
               Item('apply_btn', show_label=False, width=200, height=100),
               label='Config',
               springy=True)),
                resizable=True)

    ###########################################################################
    # Protected interface.
    ###########################################################################

    def _create_dates(self, numpoints, start=None, units="days"):
        """ Returns **numpoints** timestamps (seconds since the epoch) evenly
        spaced from **start** to **start** + **numpoints** units.

        **start** is a date/datetime-like object providing ``timetuple()``;
        when it is None the current time is used.  **units** should be one of
        "weeks", "days", "hours", "minutes", or "seconds"; any other value
        raises KeyError.
        """
        units_map = {
            "weeks": 7 * 24 * 3600,
            "days": 24 * 3600,
            "hours": 3600,
            "minutes": 60,
            "seconds": 1
        }

        if start is None:
            start = time.time()  # Now
        else:
            start = time.mktime(start.timetuple())

        dt = units_map[units]
        return np.linspace(start, start + numpoints * dt, numpoints)

    def _apply_btn_fired(self):
        """ 'Apply' button handler: rebuilds the plot. """
        self._update_plot()

    def _save_btn_fired(self):
        """ 'Save' button handler: asks for a file name and renders the
        current plot into that file.

        Does nothing if the dialog is dismissed.  If saving raises KeyError,
        an error dialog about the filename extension is shown instead.
        """
        wildcard = 'PNG file (*.png)|*.png|\nTIFF file (*.tiff)|*.tiff|'
        dialog = FileDialog(action='save as', wildcard=wildcard)

        if dialog.open() != OK:
            return

        filename = dialog.path

        width, height = self.plot.outer_bounds

        gc = PlotGraphicsContext((width, height), dpi=100)
        gc.render_component(self.plot)
        try:
            gc.save(filename)
        except KeyError as e:
            errmsg = ("The filename must have an extension that matches "
                      "a graphics format, such as '.png' or '.tiff'.")
            # 'e.message' does not exist on Python 3; a single exception
            # argument carries the same information on both 2 and 3.
            extension = str(e.args[0]) if len(e.args) == 1 else ''
            if extension != '':
                errmsg = ("Unknown filename extension: '%s'\n" %
                          extension) + errmsg

            error(None, errmsg, title="Invalid Filename Extension")
Esempio n. 22
0
class PlotApp2(HasTraits):
    """ Traits UI application for exploring PCA scores.

    It shows a 2-D scatter of two selectable components (``LeftPlot``), a
    scatter of all scores per batch (``RightPlot``), three phenotype legend
    tables (shape, color, size) and a table of per-component variance
    percentages.
    """
    numPcaScores = PCA.nums
    plotdata = Instance(ArrayPlotData)

    # Selected component for the Y axis (an index stored as a string) and the
    # list of choices offered for it
    Y_PCA = Str
    YPCA = List(Str)

    # Selected component for the X axis (an index stored as a string) and the
    # list of choices offered for it
    X_PCA = Str
    XPCA = List(Str)

    # Phenotype currently driving the marker color, its drop-down choices,
    # and the colors generated for that phenotype's categories
    Color = Str
    Colordropdownlist = List(Str)
    colors = List(str)

    # Phenotype currently driving the marker shape and its drop-down choices
    Shape = Str
    Shapedropdownlist = List(Str)

    # Phenotype currently selected for marker size and its drop-down choices
    Size = Str
    Sizedropdownlist = List(Str)

    # Category names of the phenotypes selected for shape, color and size
    shapes_name = List(Str)
    colors_name = List(Str)
    sizes_name = List(Str)

    # Which score set ("Post Scores" or "Pre Scores") is plotted
    active_scores_combobox = Enum(['Post Scores', 'Pre Scores'])
    start_selection = Button(label='Start Selection')
    stop_selection = Button(label='Stop Selection')

    # Containers shown in the view; rebuilt by the _create_*_plot methods
    RightPlot = Instance(OverlayPlotContainer)
    LeftPlot = Instance(OverlayPlotContainer)

    button_editor = ButtonEditor()

    # Rows and editor of the variance table at the bottom of the view
    table = List(Instance(MyData))
    columns = [ObjectColumn(name='name')]
    columns.append(ObjectColumn(name="Value"))
    table_editor = TableEditor(columns=columns,
                               deletable=True,
                               sortable=False,
                               sort_model=False,
                               show_lines=True,
                               line_color="black",
                               editable=False,
                               show_column_labels=False)

    # Legend tables: rows plus the column definitions their editors look up
    # through 'columns_name'
    shape_table = List(Instance(ShapeTable))
    shape_columns = List(Instance(ObjectColumn))

    color_table = List(Instance(ColorTable))
    color_columns = List(Instance(ObjectColumn))

    size_table = List(Instance(SizeTable))
    size_columns = List(Instance(ObjectColumn))

    # Layout: the top split holds the axis selectors and left plot next to
    # the legend tables and right plot; the variance table sits below.
    traits_view = View(VSplit(
        HSplit(
            VGroup(
                VGroup(
                    Item(
                        'Y_PCA',
                        editor=EnumEditor(
                            name='YPCA',
                            evaluate=validate_choice,
                        ),
                    ),
                    Item(
                        'X_PCA',
                        editor=EnumEditor(
                            name='XPCA',
                            evaluate=validate_choice,
                        ),
                    ),
                    Item('active_scores_combobox', width=225, label="Score"),
                    HGroup(
                        Item('start_selection',
                             editor=button_editor,
                             show_label=False,
                             width=0.5),
                        Item('stop_selection',
                             editor=button_editor,
                             show_label=False,
                             width=0.5)),
                ),
                Item('LeftPlot',
                     editor=ComponentEditor(),
                     show_label=False,
                     width=590,
                     height=800),
            ),
            VSplit(
                HGroup(
                    VGroup(
                        Item(
                            'Shape',
                            editor=EnumEditor(
                                name='Shapedropdownlist',
                                evaluate=validate_choice,
                            ),
                        ),
                        Item('shape_table',
                             editor=TableEditor(columns_name='shape_columns',
                                                deletable=True,
                                                sortable=False,
                                                sort_model=False,
                                                show_lines=True,
                                                line_color="black",
                                                editable=False,
                                                show_column_labels=False),
                             show_label=False,
                             width=0.3,
                             padding=5)),
                    VGroup(
                        Item(
                            'Color',
                            editor=EnumEditor(
                                name='Colordropdownlist',
                                evaluate=validate_choice,
                            ),
                        ),
                        Item('color_table',
                             editor=TableEditor(columns_name='color_columns',
                                                deletable=True,
                                                sortable=False,
                                                sort_model=False,
                                                show_lines=True,
                                                line_color="black",
                                                editable=False,
                                                show_column_labels=False),
                             show_label=False,
                             width=0.3,
                             padding=5)),
                    VGroup(
                        Item(
                            'Size',
                            editor=EnumEditor(
                                name='Sizedropdownlist',
                                evaluate=validate_choice,
                            ),
                        ),
                        Item('size_table',
                             editor=TableEditor(columns_name='size_columns',
                                                deletable=True,
                                                sortable=False,
                                                sort_model=False,
                                                show_lines=True,
                                                line_color="black",
                                                editable=False,
                                                show_column_labels=False),
                             show_label=False,
                             width=0.3,
                             padding=5)),
                ),
                Item('RightPlot',
                     editor=ComponentEditor(),
                     show_label=False,
                     height=640),
            )), Item('table',
                     editor=table_editor,
                     show_label=False,
                     padding=15)),
                       width=1100,
                       height=700,
                       resizable=True,
                       title="Principal Components Visualizer")

    def __init__(self, PCAData):
        """ Loads the phenotype table, stores **PCAData** and builds the
        initial tables and plots.

        **PCAData** must support ``len()`` and expose a ``batchs`` sequence;
        ``len(PCAData) - 1`` is used as the number of PCA components.
        """
        super(PlotApp2, self).__init__()
        # Raw string: '\I' and '\p' are not valid escape sequences, so the
        # non-raw spelling only worked by accident of the parser.
        self.phenotypes, self.pheno_dict = readPhenotypesFromCSVFile_pd(
            r'..\IOdata\phenotypes_table_2.csv')
        print(self.phenotypes, self.pheno_dict)
        self.shapes_name = self.pheno_dict[self.phenotypes[1]]
        self.colors_name = self.pheno_dict[self.phenotypes[2]]
        self.sizes_name = self.pheno_dict[self.phenotypes[3]]
        self.colors = get_colors(len(self.colors_name))

        # One table column per PCA component, after the row-name column.
        self.table_editor.columns = [ObjectColumn(name='name')]
        for i in range(len(PCAData) - 1):
            self.table_editor.columns.append(ObjectColumn(name="PCA" + str(i)))

        self.PCAData = PCAData
        self.YPCA = [str(i) for i in range(len(PCAData) - 1)]
        self.XPCA = [str(i) for i in range(len(PCAData) - 1)]

        self.Shapedropdownlist = self.phenotypes
        self.Colordropdownlist = self.phenotypes
        self.Sizedropdownlist = self.phenotypes

        # NOTE(review): the assignments below presumably trigger the
        # matching _<name>_changed methods, which is where pshape, pcolor
        # and psize (read by the _update*Table calls further down) get set
        # -- keep this ordering.
        self.X_PCA = '0'
        self.Y_PCA = '0'

        self.Shape = self.phenotypes[1]
        self.Color = self.phenotypes[2]
        self.Size = self.phenotypes[3]

        self.activeScore = 'Pre Scores'
        self._updateTable()
        self._updateShapeTable()
        self._updateColorTable()
        self._updateSizeTable()
        self._update_Both_graph()

    def _getPCAArray(self, pcaIndex):
        """ Returns a list with the score at position **pcaIndex** of every
        batch, taken from the post scores when "Post Scores" is active and
        from the pre scores otherwise.
        """
        return [
            batch.postscores[pcaIndex]
            if self.active_scores_combobox == "Post Scores"
            else batch.prescores[pcaIndex]
            for batch in self.PCAData.batchs
        ]

    def _create_1D1_plot(self):
        """ Rebuilds ``RightPlot``: one scatter renderer per batch plotting
        that batch's scores against their positions, all sharing the mappers
        of the first renderer.

        Marker shape and color are looked up from the batch according to the
        phenotypes selected in ``Shape`` and ``Color``.
        """

        def phenotype_index(selected, batch, first_choice):
            # Index into the marker/color list for the selected phenotype.
            if selected == self.phenotypes[0]:
                return first_choice
            if selected == self.phenotypes[1]:
                return batch.number
            if selected == self.phenotypes[2]:
                return batch.type
            return 0

        container = OverlayPlotContainer(padding=50,
                                         fill_padding=True,
                                         bgcolor="lightgray",
                                         use_backbuffer=True)

        # Value axis: positions 0..k-1 of the scores of the first batch.
        y1 = range(len(self.PCAData.batchs[0].prescores))

        for index, batch in enumerate(self.PCAData.batchs):
            if self.active_scores_combobox == "Post Scores":
                x1 = batch.postscores
            else:
                x1 = batch.prescores

            tmarker = shapes[phenotype_index(self.Shape, batch, 1)]
            bcolor = self.colors[phenotype_index(self.Color, batch, 0)]

            plot0 = create_scatter_plot((x1, y1),
                                        marker=tmarker,
                                        color=getColor(bcolor))

            # Unselected batches are drawn faded.
            if batch.isSelected:
                plot0.alpha = 1
            else:
                plot0.alpha = 0.2

            plot0.bgcolor = "white"
            plot0.border_visible = True

            if index == 0:
                # The first renderer owns the mappers, decorations and tools.
                value_mapper = plot0.value_mapper
                index_mapper = plot0.index_mapper
                add_default_grids(plot0)
                add_default_axes(plot0,
                                 vtitle='PCA Indices',
                                 htitle='PCA Scores')
                plot0.index_range.tight_bounds = False
                plot0.index_range.refresh()
                plot0.value_range.tight_bounds = False
                plot0.value_range.refresh()
                plot0.tools.append(PanTool(plot0))
                zoom = ZoomTool(plot0,
                                tool_mode="box",
                                always_on=False,
                                maintain_aspect_ratio=False)
                plot0.overlays.append(zoom)
                dragzoom = DragZoom(plot0,
                                    drag_button="right",
                                    maintain_aspect_ratio=False)
                plot0.tools.append(dragzoom)
            else:
                # Later renderers reuse the first one's mappers and add
                # their data to its ranges.
                plot0.value_mapper = value_mapper
                value_mapper.range.add(plot0.value)
                plot0.index_mapper = index_mapper
                index_mapper.range.add(plot0.index)

            container.add(plot0)

        self.RightPlot = container

    def _create_2D_plot(self):
        """ Rebuilds ``LeftPlot``: one single-point scatter renderer per
        batch at (score[X_PCA], score[Y_PCA]), all sharing the mappers of the
        first renderer, with a LineDrawer2D overlay on the last one.

        If a ValueError is raised while building (for example when ``X_PCA``
        or ``Y_PCA`` is not an integer string), ``LeftPlot`` is left
        unchanged.
        """

        def phenotype_index(selected, batch, first_choice):
            # Index into the marker/color list for the selected phenotype.
            if selected == self.phenotypes[0]:
                return first_choice
            if selected == self.phenotypes[1]:
                return batch.number
            if selected == self.phenotypes[2]:
                return batch.type
            return 0

        secContainer = OverlayPlotContainer(padding=50,
                                            fill_padding=True,
                                            bgcolor="lightgray",
                                            use_backbuffer=True)
        try:
            x_index = int(self.X_PCA)
            y_index = int(self.Y_PCA)
            use_post = self.active_scores_combobox == "Post Scores"

            # Stays None when there are no batches, so the overlay step
            # below can be skipped instead of failing on an unbound name.
            plot = None
            for index, batch in enumerate(self.PCAData.batchs):
                scores = batch.postscores if use_post else batch.prescores
                x = [scores[x_index]]
                y = [scores[y_index]]

                tmarker = shapes[phenotype_index(self.Shape, batch, 1)]
                bcolor = self.colors[phenotype_index(self.Color, batch, 0)]

                plot = create_scatter_plot((x, y),
                                           marker=tmarker,
                                           color=getColor(bcolor))
                # NOTE(review): unselected batches set 'fill_alpha' here but
                # 'alpha' in _create_1D1_plot -- confirm which is intended.
                if batch.isSelected:
                    plot.alpha = 1
                else:
                    plot.fill_alpha = 0.2
                plot.bgcolor = "white"
                plot.border_visible = True

                if index == 0:
                    # The first renderer owns the mappers, decorations and
                    # tools.
                    value_mapper = plot.value_mapper
                    index_mapper = plot.index_mapper
                    add_default_grids(plot)
                    add_default_axes(plot,
                                     vtitle='PCA ' + self.Y_PCA,
                                     htitle='PCA ' + self.X_PCA)
                    plot.index_range.tight_bounds = False
                    plot.index_range.refresh()
                    plot.value_range.tight_bounds = False
                    plot.value_range.refresh()

                    plot.tools.append(PanTool(plot))
                    zoom = ZoomTool(plot, tool_mode="box", always_on=False)
                    plot.overlays.append(zoom)
                    dragzoom = DragZoom(plot, drag_button="right")
                    plot.tools.append(dragzoom)
                else:
                    # Later renderers reuse the first one's mappers and add
                    # their data to its ranges.
                    plot.value_mapper = value_mapper
                    value_mapper.range.add(plot.value)
                    plot.index_mapper = index_mapper
                    index_mapper.range.add(plot.index)

                secContainer.add(plot)

            if plot is not None:
                lineDraw = LineDrawer2D(plot)
                lineDraw.setPCAData(self, self.PCAData)
                plot.overlays.append(lineDraw)
            self.LeftPlot = secContainer
        except ValueError:
            # Deliberate best effort: keep the previous plot rather than
            # fail while the selections are not yet usable.
            return

    def _Y_PCA_changed(self, selectedValue):
        """ Stores **selectedValue** as the Y-axis component and rebuilds the
        2-D plot.

        NOTE(review): presumably called by Traits after ``Y_PCA`` has already
        changed, in which case the assignment below is redundant -- confirm.
        """
        self.Y_PCA = selectedValue
        self._create_2D_plot()

    def _X_PCA_changed(self, selectedValue):
        """ Stores **selectedValue** as the X-axis component and rebuilds the
        2-D plot.

        NOTE(review): presumably called by Traits after ``X_PCA`` has already
        changed, in which case the assignment below is redundant -- confirm.
        """
        self.X_PCA = selectedValue
        self._create_2D_plot()

    def _Color_changed(self, selectedValue):
        """ Makes **selectedValue** the phenotype that drives marker colors:
        reloads its category names, regenerates one color per category, then
        refreshes the color table and both plots.
        """
        self.pcolor = selectedValue
        # Raises KeyError if the phenotype is not a key of pheno_dict.
        self.colors_name = self.pheno_dict[self.pcolor]
        self.colors = get_colors(len(self.colors_name))
        #print(self.Color, self.colors_name, self.colors)
        self._updateColorTable()
        self._update_Both_graph()

    def _Size_changed(self, selectedValue):
        """ Makes **selectedValue** the phenotype selected for marker size:
        reloads its category names, then refreshes the size table and both
        plots.
        """
        self.psize = selectedValue
        # Raises KeyError if the phenotype is not a key of pheno_dict.
        self.sizes_name = self.pheno_dict[self.psize]
        self._updateSizeTable()
        self._update_Both_graph()

    def _Shape_changed(self, selectedValue):
        """ Makes **selectedValue** the phenotype that drives marker shapes:
        reloads its category names, then refreshes the shape table and both
        plots.
        """
        self.pshape = selectedValue
        # Raises KeyError if the phenotype is not a key of pheno_dict.
        self.shapes_name = self.pheno_dict[self.pshape]
        self._updateShapeTable()
        self._update_Both_graph()

    def _active_scores_combobox_changed(self):
        """ Refreshes both plots and the variance table for the newly chosen
        score set.
        """
        self._update_Both_graph()

    def _start_selection_fired(self):
        """ 'Start Selection' button handler: marks every batch as not
        selected and rebuilds both plots.
        """
        for batch in self.PCAData.batchs:
            batch.isSelected = False
        self._create_1D1_plot()
        self._create_2D_plot()

    def _stop_selection_fired(self):
        """ 'Stop Selection' button handler: marks every batch as selected
        and rebuilds both plots.
        """
        for batch in self.PCAData.batchs:
            batch.isSelected = True
        self._create_1D1_plot()
        self._create_2D_plot()

    def _updateShapeTable(self):
        """ Rebuilds the shape legend: one column per category of the
        selected shape phenotype, a row of category names and a row with the
        marker assigned to each category.

        Raises IndexError when there are more categories than entries in
        ``shapes``.
        """
        del self.shape_table

        columns = [ObjectColumn(name='name')]
        columns.extend(
            ObjectColumn(name='s' + category) for category in self.shapes_name)

        # The columns are swapped while a placeholder row is in the table.
        # NOTE(review): presumably needed for the editor to pick up the new
        # columns -- confirm before simplifying.
        placeholder = ShapeTable()
        self.shape_table.append(placeholder)
        self.shape_columns = columns
        self.shape_table.remove(placeholder)

        # setattr instead of exec-built source text: works for any category
        # name or value, including ones containing quotes.
        names_row = ShapeTable()
        names_row.name = self.pshape
        for category in self.shapes_name:
            setattr(names_row, 's' + category, category)
        self.shape_table.append(names_row)

        markers_row = ShapeTable()
        markers_row.name = "Shape"
        for i, category in enumerate(self.shapes_name):
            setattr(markers_row, 's' + category, shapes[i])
        self.shape_table.append(markers_row)

    def _updateColorTable(self):
        """ Rebuilds the color legend: one column per category of the
        selected color phenotype (each column tinted with that category's
        color), a row of category names and a row of empty cells.
        """
        del self.color_table

        columns = [ObjectColumn(name='name')]
        for i, category in enumerate(self.colors_name):
            columns.append(
                ObjectColumn(name='s' + category,
                             cell_color=getColor(self.colors[i])))

        # The columns are swapped while a placeholder row is in the table.
        # NOTE(review): presumably needed for the editor to pick up the new
        # columns -- confirm before simplifying.
        placeholder = ColorTable()
        self.color_table.append(placeholder)
        self.color_columns = columns
        self.color_table.remove(placeholder)

        # setattr instead of exec-built source text: works for any category
        # name, including ones containing quotes.
        names_row = ColorTable()
        names_row.name = self.pcolor
        for category in self.colors_name:
            setattr(names_row, 's' + category, category)
        self.color_table.append(names_row)

        # The cells stay empty; the color comes from the column's cell_color.
        swatch_row = ColorTable()
        swatch_row.name = "Color"
        for category in self.colors_name:
            setattr(swatch_row, 's' + category, "")
        self.color_table.append(swatch_row)

    def _updateSizeTable(self):
        """ Rebuilds the size legend: one column per category of the selected
        size phenotype, a row of category names and a row with the size
        assigned to each category.

        Raises IndexError when there are more categories than entries in
        ``sizes``.
        """
        del self.size_table

        columns = [ObjectColumn(name='name')]
        columns.extend(
            ObjectColumn(name='s' + category) for category in self.sizes_name)

        # The columns are swapped while a placeholder row is in the table.
        # NOTE(review): presumably needed for the editor to pick up the new
        # columns -- confirm before simplifying.
        placeholder = SizeTable()
        self.size_table.append(placeholder)
        self.size_columns = columns
        self.size_table.remove(placeholder)

        # setattr instead of exec-built source text: works for any category
        # name or value, including ones containing quotes.
        names_row = SizeTable()
        names_row.name = self.psize
        for category in self.sizes_name:
            setattr(names_row, 's' + category, category)
        self.size_table.append(names_row)

        sizes_row = SizeTable()
        sizes_row.name = "Size"
        for i, category in enumerate(self.sizes_name):
            setattr(sizes_row, 's' + category, sizes[i])
        self.size_table.append(sizes_row)

    def _updateTable(self):
        """ Rebuilds the variance table with three rows: the component
        indices, each component's share of the total variance in percent,
        and the cumulative share.  Percentages keep two decimals.
        """
        numPcaScores = len(self.PCAData) - 1
        del self.table

        # setattr replaces the former exec-built assignments throughout.
        index_row = MyData()
        index_row.name = 'PCA Index'
        for i in range(numPcaScores):
            setattr(index_row, 'PCA' + str(i), i)
        self.table.append(index_row)

        # Variance of every component across batches, and their total.
        pca_vars = [var(self._getPCAArray(i)) for i in range(numPcaScores)]
        sumVar = 0.0
        for variance in pca_vars:
            sumVar = sumVar + variance

        power_row = MyData()
        power_row.name = 'Percentage Power'
        for i in range(numPcaScores):
            percent = 100 * pca_vars[i] / sumVar
            # Formatting then parsing yields the two-decimal value.
            setattr(power_row, 'PCA' + str(i), float('%0.2f' % percent))
        self.table.append(power_row)

        cumulative_row = MyData()
        cumulative_row.name = 'Cumulative Percentage Power'
        sumPercent = 0.0
        for i in range(numPcaScores):
            sumPercent = sumPercent + 100 * pca_vars[i] / sumVar
            setattr(cumulative_row, 'PCA' + str(i),
                    float('%0.2f' % sumPercent))
        self.table.append(cumulative_row)

    def _update_Both_graph(self):
        """ Records the active score set, then rebuilds the 2-D plot, the
        per-component plot and the variance table.
        """
        self.activeScore = self.active_scores_combobox
        self._create_2D_plot()
        self._create_1D1_plot()
        self._updateTable()
Esempio n. 23
0
    def __init__(self,
                 pipeline,
                 name,
                 inputs,
                 outputs,
                 make_optional=(),
                 output_types=None):
        """ Generate a Switch Node

        Warnings
        --------
        The input plug names are built according to the following rule:
        <input_name>_switch_<output_name>

        Parameters
        ----------
        pipeline: Pipeline (mandatory)
            the pipeline object where the node is added
        name: str (mandatory)
            the switch node name
        inputs: list (mandatory)
            a list of options
        outputs: list (mandatory)
            a list of output parameters; a single non-list value is wrapped
            in a one-element list
        make_optional: sequence (optional)
            list of optional outputs.
            These outputs will be made optional in the switch output. By
            default they are mandatory.
        output_types: sequence of traits (optional)
            If given, this sequence should have the same size as outputs. It
            will specify each switch output parameter type (as a standard
            trait). Input parameters for each input block will also have this
            type.

        Raises
        ------
        ValueError
            if output_types is neither a list nor a tuple, or its length
            differs from the number of outputs
        Exception
            if inputs is not a list
        """
        self.__block_output_propagation = False

        # a single output is promoted to a one-element list
        if not isinstance(outputs, list):
            outputs = [outputs]

        if output_types is None:
            output_types = [Any(Undefined)] * len(outputs)
        elif not isinstance(output_types, (list, tuple)):
            raise ValueError(
                'output_types parameter should be a list or tuple')
        elif len(output_types) != len(outputs):
            raise ValueError('output_types should have the same number of '
                             'elements as outputs')

        # check consistency
        if not (isinstance(inputs, list) and isinstance(outputs, list)):
            raise Exception("The Switch node input and output parameters "
                            "are inconsistent: expect list, "
                            "got {0}, {1}".format(type(inputs), type(outputs)))

        # private copy of outputs and inputs
        self._outputs = outputs
        self._switch_values = inputs

        # one input plug per (switch value, output) pair, grouped by value
        flat_inputs = [
            "{0}_switch_{1}".format(switch_name, plug_name)
            for switch_name in inputs
            for plug_name in outputs
        ]
        node_inputs = [dict(name="switch")]
        node_inputs += [dict(name=plug, optional=True) for plug in flat_inputs]
        node_outputs = [
            dict(name=plug, optional=(plug in make_optional))
            for plug in outputs
        ]

        # inherit from Node class
        super(Switch, self).__init__(pipeline, name, node_inputs, node_outputs)

        # every plug except "switch" starts disabled
        for plug_name in flat_inputs:
            self.plugs[plug_name].enabled = False

        # add switch enum trait to select the process
        self.add_trait("switch", Enum(*inputs, output=False))

        def declare(param, trait, is_output):
            # register the trait and copy the optional flag from its plug
            self.add_trait(param, trait)
            self.trait(param).output = is_output
            self.trait(param).optional = self.plugs[param].optional

        # add a trait for each input and each output; the output types
        # repeat once per switch value
        for param, trait in zip(flat_inputs, output_types * len(inputs)):
            declare(param, trait, False)
        for param, trait in zip(outputs, output_types):
            declare(param, trait, True)

        # activate the switch first Process
        first_value = self._switch_values[0]
        self._switch_changed(first_value, first_value)
Esempio n. 24
0
class Tool(HasTraits):
    """ Describes a tool by its kind and the button bound to it. """
    # Kind of tool: "pan", "zoom" or "regression"
    type = Enum("pan", "zoom", "regression")
    # Button bound to the tool, or None -- presumably a mouse button; confirm
    button = Enum(None, "left", "right")
Esempio n. 25
0
class Dialog(MDialog, Window):
    """ wx-backed implementation of a Dialog.

    The public API is documented on the IDialog interface; this class only
    supplies the toolkit-specific widget construction and event handling.
    """

    # 'IDialog' interface -------------------------------------------------#

    cancel_label = Str()

    help_id = Str()

    help_label = Str()

    ok_label = Str()

    resizeable = Bool(True)

    return_code = Int(OK)

    style = Enum("modal", "nonmodal")

    # 'IWindow' interface -------------------------------------------------#

    title = Str("Dialog")

    # ------------------------------------------------------------------------
    # Protected 'IDialog' interface.
    # ------------------------------------------------------------------------

    def _create_buttons(self, parent):
        """ Build the standard button row and return its sizer.

        OK and Cancel are always created; Help is only created when
        ``help_id`` is non-empty.  Each caption falls back to a stock label
        when the corresponding ``*_label`` trait is empty.
        """
        button_row = wx.StdDialogButtonSizer()

        # OK is the dialog's default button.
        ok_button = wx.Button(parent, wx.ID_OK, self.ok_label or "OK")
        self._wx_ok = ok_button
        ok_button.SetDefault()
        parent.Bind(wx.EVT_BUTTON, self._wx_on_ok, id=wx.ID_OK)
        button_row.AddButton(ok_button)

        cancel_button = wx.Button(
            parent, wx.ID_CANCEL, self.cancel_label or "Cancel"
        )
        self._wx_cancel = cancel_button
        parent.Bind(wx.EVT_BUTTON, self._wx_on_cancel, id=wx.ID_CANCEL)
        button_row.AddButton(cancel_button)

        if len(self.help_id) > 0:
            help_button = wx.Button(
                parent, wx.ID_HELP, self.help_label or "Help"
            )
            parent.Bind(wx.EVT_BUTTON, self._wx_on_help, id=wx.ID_HELP)
            button_row.AddButton(help_button)

        button_row.Realize()
        return button_row

    def _create_contents(self, parent):
        """ Lay out the dialog area above the button row inside *parent*. """
        layout = wx.BoxSizer(wx.VERTICAL)
        parent.SetSizer(layout)
        parent.SetAutoLayout(True)

        # The dialog area takes all spare space; the buttons keep their
        # natural height and sit on the right.
        layout.Add(
            self._create_dialog_area(parent), 1, wx.EXPAND | wx.ALL, 5
        )
        layout.Add(
            self._create_buttons(parent), 0, wx.ALIGN_RIGHT | wx.ALL, 5
        )

        # (-1, -1) means no explicit size was requested, so shrink the
        # dialog to the sizer's minimal size instead.
        if self.size == (-1, -1):
            layout.Fit(parent)
        else:
            parent.SetSize(self.size)

        parent.CentreOnParent()

    def _create_dialog_area(self, parent):
        """ Return the panel that forms the main body of the dialog.

        This implementation produces a red 100x200 placeholder panel
        (presumably meant to be overridden by subclasses).
        """
        placeholder = wx.Panel(parent, -1)
        placeholder.SetBackgroundColour("red")
        placeholder.SetSize((100, 200))
        return placeholder

    def _show_modal(self):
        """ Run the dialog modally and return the mapped result code. """
        if sys.platform == "darwin":
            # On the Mac the modal dialog does not show up at all unless
            # Show(False) is called first.
            self.control.Show(False)
        wx_result = self.control.ShowModal()
        return _RESULT_MAP[wx_result]

    # ------------------------------------------------------------------------
    # Protected 'IWidget' interface.
    # ------------------------------------------------------------------------

    def _create_control(self, parent):
        """ Create the underlying wx.Dialog for this dialog. """
        window_style = wx.DEFAULT_DIALOG_STYLE | wx.CLIP_CHILDREN
        if self.resizeable:
            window_style = window_style | wx.RESIZE_BORDER

        return wx.Dialog(
            parent, -1, self.title, self.position, self.size, window_style
        )

    # wx event handlers ----------------------------------------------------

    def _wx_on_ok(self, event):
        """ Record the OK result when the 'OK' button is pressed. """
        self.return_code = OK
        # Skip() lets the default handler close the dialog appropriately.
        event.Skip()

    def _wx_on_cancel(self, event):
        """ Record the CANCEL result when the 'Cancel' button is pressed. """
        self.return_code = CANCEL
        # Skip() lets the default handler close the dialog appropriately.
        event.Skip()

    def _wx_on_help(self, event):
        """ Called when the 'Help' button is pressed; does nothing here. """
Esempio n. 26
0
class Aesthetic(HasTraits):
    """ Acts as a data and visual configuration template which is
    passed in to gpplot().

    Every trait defaults to None, which means "not specified"; unspecified
    values can be filled in from another Aesthetic with ``+`` or from the
    DefaultStyle mapping with merge_defaults().
    """

    x = Trait(None, Str)
    y = Trait(None, Str)

    color = Trait(None, Str)
    fill = Trait(None, Str)

    # Shape can be:
    #   an integer in 0..25 for the various symbols
    #   a single character to use as the literal symbol
    #   a "." to draw the smallest rectangle
    #   None, to draw nothing
    # All symbols have a foreground color; symbols 19-25 also have bgcolor
    shape = Trait(None, Str, Int)

    # Size in millimeters
    size = Trait(None, Int)

    # Justification can also be a number between 0..1, giving position
    # within the string
    #justification = Enum("left", "right", "center", "bottom", "top")

    line_type = Trait(
        None,
        Enum("solid", "dashed", "dotted", "dotdash", "longdash", "twodash",
             "blank"))
    line_weight = Trait(None, Int)

    #binwidth = Float()
    #label = Str()
    #ymin = Float()
    #ymax = Float()
    #group = Str()

    def __init__(self, x=None, y=None, **kwtraits):
        """ Create an Aesthetic; *x* and *y* may be given positionally.

        Falsy *x* / *y* values (None, empty string) are ignored and leave
        the corresponding trait at its default.
        """
        super(Aesthetic, self).__init__(**kwtraits)
        if x:
            self.x = x
        if y:
            self.y = y

    def __add__(self, aes):
        """ Returns a new Aesthetic class that represents the merger of this
        instance with another one.  This instance (LHS) takes lower
        precedence, and its values are masked by ones in RHS argument.

        Neither operand is modified: the merge is applied to a clone of the
        LHS, and only RHS traits whose value is not None override it.
        """
        merged = self.clone_traits()
        # These two names are skipped, as in merge_defaults(); they are not
        # treated as style values.
        skipped = set(("trait_modified", "trait_added"))
        for trait_name in set(aes.trait_names()) - skipped:
            value = getattr(aes, trait_name)
            if value is not None:
                # Write to the clone, not to self, so the LHS stays intact
                # and the returned object actually carries the merge.
                setattr(merged, trait_name, value)
        return merged

    def merge_defaults(self):
        """ Fills in all trait values that are None with values from the
        DefaultStyle dictionary.

        Modifies this instance in place.  A trait that is None and has no
        entry in DefaultStyle makes the lookup fail.
        """
        for trait_name in set(self.trait_names()) - set(
            ("trait_modified", "trait_added")):
            if getattr(self, trait_name) is None:
                setattr(self, trait_name, DefaultStyle[trait_name])
Esempio n. 27
0
class MayaviGrid(HasTraits):
    ''' Plots the data in a vlsv file as a Mayavi grid.

   The following brings up a new window and plots the grid in the vlsv file:

   .. code-block:: python

      grid = pt.grid.MayaviGrid(vlsvReader=f, variable="rho", operator='pass', threaded=False)

   Once the window is open you can use the picker tool in the upper-right corner and the various point-and-click tools for analyzing data.

   Picker options:
   
   **None** Does nothing upon clicking somewhere in the grid.
   
   **Velocity_space** Plots the velocity space at the clicked position. Note: If the vlsv file does not have the velocity space at the position where you are clicking, this will not work.
   
   **Velocity_space_iso_surface** Plots the velocity space at the clicked position in iso-surface plotting style. Note: If the vlsv file does not have the velocity space at the position where you are clicking, this will not work.
   
   **Velocity_space_nearest_cellid** Plots the velocity space of the cell id closest to the picking point. Note: If the vlsv file does not have velocity space saved at all, this will not work.
   
   **Velocity_space_nearest_cellid_iso_surface** Plots the velocity space of the cell id closest to the picking point in iso-surface plotting style. Note: If the vlsv file does not have velocity space saved at all, this will not work.
   
   **Pitch_angle** Plots the pitch angle distribution at the clicked position. Note: If the vlsv file does not have the velocity space at the position where you are clicking, this will not work.
   
   **Gyrophase_angle** Plots the gyrophase angle distribution at the clicked position. Note: If the vlsv file does not have the velocity space at the position where you are clicking, this will not work.
   
   **Cut_through** Plots or saves the cut-through between two clicked points. This option requires the args field at the top-left. To plot variables, write for example: **plot rho B,x E,y**. After clicking two points, a new window opens with a cut-through plot of rho, the x-component of B and the y-component of E. Alternatively, you can save the cut-through to a variable in the MayaviGrid class by typing instead: **rho B,x E,y** and then going to the terminal and typing
   
   .. code-block:: python

      cut_through_data = grid.cut_through
      print cut_through_data

   '''
    # Tool applied when a point of the grid is clicked; the choices are
    # described in the class docstring.
    picker = Enum('None', 'Velocity_space', "Velocity_space_nearest_cellid",
                  'Velocity_space_iso_surface',
                  'Velocity_space_nearest_cellid_iso_surface', "Pitch_angle",
                  "Gyrophase_angle", "Cut_through")

    # Free-form argument string; the Cut_through picker splits it on
    # whitespace to decide what to read and whether to plot.
    args = ""

    # Name of the variable given to the constructor.
    variable_plotted = ""

    # Replaced by the most recently created Labels module.
    labels = []

    # Replaced by the result of the most recent cut-through.
    cut_through = []

    # Replaced by the figure returned for a "rankine" cut-through.
    plot = []

    scene = Instance(MlabSceneModel, ())

    engine_view = Instance(EngineView)

    # NOTE(review): no getter for this property is visible in this part of
    # the file -- confirm where it is defined.
    current_selection = Property

    # Replaced by the dataset source created when the grid is generated.
    dataset = []

    # Define the view:
    view = View(
        HGroup(
            Item('scene',
                 editor=SceneEditor(scene_class=MayaviScene),
                 height=250,
                 width=300,
                 show_label=False,
                 resizable=True),
            Group(
                #'cell_pick',
                'picker',
                'args',
                show_labels=True),
        ),
        resizable=True,
    )

    def __init__(self,
                 vlsvReader,
                 variable,
                 operator="pass",
                 threaded=True,
                 **traits):
        ''' Initializes the class and loads the mayavi grid

          :param vlsvReader:        Some vlsv reader with a file open
          :type vlsvReader:         :class:`vlsvfile.VlsvReader`
          :param variable:          Name of the variable
          :param operator:          Operator for the variable
          :param threaded:          Boolean value for using threads or not using threads to draw the grid (threads enable interactive mode)
          :param traits:            Extra trait values forwarded to HasTraits
      '''
        HasTraits.__init__(self, **traits)
        self.__vlsvReader = vlsvReader
        self.engine_view = EngineView(engine=self.scene.engine)
        self.__engine = self.scene.engine
        self.__picker = []
        self.__mins = []
        self.__maxs = []
        self.__cells = []
        self.__last_pick = []
        self.__structured_figures = []
        self.__unstructured_figures = []
        self.__thread = []
        # Record the variable name BEFORE loading the grid: with
        # threaded=False the load ends in a blocking configure_traits()
        # call, and the velocity-grid picker reads variable_plotted while
        # that window is open.
        self.variable_plotted = variable
        self.__load_grid(variable=variable,
                         operator=operator,
                         threaded=threaded)

    def __module_manager(self):
        ''' Returns the first ModuleManager found below the mayavi scene

          The search starts at ``self.scene.mayavi_scene`` and repeatedly
          descends into the first child of the current node; the scene
          itself is never tested.  If a node with an empty ``children``
          sequence is reached first, the ``children[0]`` lookup fails.

          :returns: The module manager node
      '''
        from mayavi.core.module_manager import ModuleManager
        node = self.scene.mayavi_scene
        while True:
            node = node.children[0]
            # isinstance avoids constructing a throwaway ModuleManager on
            # every iteration just to compare types.
            if isinstance(node, ModuleManager):
                return node

    def __add_label(self, cellid):
        ''' Adds a single label marking the given spatial cell to the scene

          The new Labels module is stored in ``self.labels`` and attached
          to the module manager of the main scene.

          :param cellid:          The spatial cell's ID
      '''
        from mayavi.modules.labels import Labels
        cell_indices = self.__vlsvReader.get_cell_indices(cellid)

        self.labels = Labels()
        self.labels.number_of_labels = 1
        label_mask = self.labels.mask.filter
        label_mask.random_mode = False

        # Flatten the (x, y, z) cell indices using strides of (cells + 1)
        # per axis.
        # NOTE(review): the z index is shifted by one whole plane; the
        # reason is not evident from this method -- confirm.
        row_stride = self.__cells[0] + 1
        plane_stride = row_stride * (self.__cells[1] + 1)
        label_mask.offset = int(cell_indices[0] +
                                row_stride * cell_indices[1] +
                                plane_stride * (cell_indices[2] + 1))

        # Add the label / marker:
        self.__engine.add_filter(self.labels, self.__module_manager())

    def __add_normal_labels(self, point1, point2):
        ''' Labels two cells lying on the normal of the segment point1-point2

          The normal is obtained by applying ``rotation_matrix_2d(-pi/2)``
          to the segment (presumably a quarter-turn in the x-y plane),
          scaling by the inverse segment length and zeroing the z
          component.  The two labelled cells are located 8 x-cell-widths
          away from the segment's midpoint, one on each side.

          :param point1:          First end point of the segment
          :param point2:          Second end point of the segment
      '''
        # Only the x resolution is needed for the label offset.
        xcells = int(self.__vlsvReader.read_parameter("xcells_ini"))
        xmin = self.__vlsvReader.read_parameter("xmin")
        xmax = self.__vlsvReader.read_parameter("xmax")
        dx = (xmax - xmin) / float(xcells)

        point1 = np.array(point1)
        point2 = np.array(point2)
        segment = point2 - point1

        normal_vector = np.dot(rotation_matrix_2d(-0.5 * np.pi),
                               segment) / np.linalg.norm(segment)
        # Keep the normal inside the x-y plane.
        normal_vector = normal_vector * np.array([1, 1, 0])

        midpoint = point1 + 0.5 * segment
        shift = normal_vector * (8 * dx)

        cellid1 = self.__vlsvReader.get_cellid(midpoint - shift)
        cellid2 = self.__vlsvReader.get_cellid(midpoint + shift)

        # Input label:
        self.__add_label(cellid1)
        self.__add_label(cellid2)

    def __load_grid(self, variable, operator="pass", threaded=True):
        ''' Reads the grid extents and one variable from the vlsv file and draws them

          The variable values are reordered by ascending cell id before
          they are handed to the grid generator.

          :param variable:        Name of the variable to plot
          :param operator:        Operator for the variable
          :param threaded:        Boolean value for using threads or not using threads to draw the grid (threads enable interactive mode)
      '''
        reader = self.__vlsvReader

        # Grid extents and resolution:
        grid_mins = np.array([
            reader.read_parameter(name) for name in ("xmin", "ymin", "zmin")
        ])
        grid_cells = np.array([
            reader.read_parameter(name)
            for name in ("xcells_ini", "ycells_ini", "zcells_ini")
        ])
        grid_maxs = np.array([
            reader.read_parameter(name) for name in ("xmax", "ymax", "zmax")
        ])

        # Variable values, indexed by each cell's location in the file:
        location_by_cellid = reader.get_cellid_locations()
        values = reader.read_variable(name=variable, operator=operator)

        # Walk the cells in ascending cell id order:
        by_cellid = sorted(location_by_cellid.iteritems(),
                           key=lambda entry: entry[0])
        values_by_cellid = [values[location] for _cellid, location in by_cellid]

        # Remember the grid geometry for the picker and the labels:
        self.__mins = grid_mins
        self.__maxs = grid_maxs
        self.__cells = grid_cells

        # Draw the grid, either in a worker thread or in this one:
        grid_args = (grid_mins, grid_maxs, grid_cells, values_by_cellid,
                     variable)
        if threaded == True:
            threading.Thread(target=self.__generate_grid,
                             args=grid_args).start()
        else:
            self.__generate_grid(*grid_args)

    def __picker_callback(self, picker):
        """ This gets called when clicking on a cell

          The clicked position is nudged back inside the grid when it lies
          just outside it, translated to a cell id, and then handled
          according to the currently selected ``self.picker`` tool.

          :param picker:          Pick event object exposing ``pick_position``
      """
        if self.picker != "Cut_through":
            # Make sure the last pick is null (used in cut_through)
            self.__last_pick = []

        position = picker.pick_position
        coordinates = np.array([position[0], position[1], position[2]])
        # Tolerance for picks landing slightly outside the grid
        epsilon = 80
        for axis in range(3):
            lower = self.__mins[axis]
            upper = self.__maxs[axis]
            if (coordinates[axis] < lower) and (coordinates[axis] + epsilon
                                                > lower):
                # Just below the grid: move to just inside the lower bound
                coordinates[axis] = lower + 1
            if (coordinates[axis] > upper) and (coordinates[axis] - epsilon
                                                < upper):
                # Just above the grid: move to just inside the upper bound
                coordinates[axis] = upper - 1
        print("COORDINATES:" + str(coordinates))
        cellid = self.__vlsvReader.get_cellid(coordinates)
        print("CELL ID: " + str(cellid))
        # Check for an invalid cell id
        if cellid == 0:
            print("Invalid cell id")
            return

        mode = self.picker
        if mode == "Velocity_space":
            # Set label to give out the location of the cell:
            self.__add_label(cellid)
            self.__generate_velocity_grid(cellid)
        elif mode == "Velocity_space_nearest_cellid":
            cellid = self.__nearest_cell_with_blocks(cellid)
            self.__add_label(cellid)
            self.__generate_velocity_grid(cellid)
        elif mode == "Velocity_space_iso_surface":
            self.__add_label(cellid)
            self.__generate_velocity_grid(cellid, True)
        elif mode == "Velocity_space_nearest_cellid_iso_surface":
            cellid = self.__nearest_cell_with_blocks(cellid)
            self.__add_label(cellid)
            self.__generate_velocity_grid(cellid, True)
        elif mode == "Pitch_angle":
            self.__add_label(cellid)
            # Plot pitch angle distribution:
            from pitchangle import pitch_angles
            result = pitch_angles(vlsvReader=self.__vlsvReader,
                                  cellid=cellid,
                                  cosine=True,
                                  plasmaframe=True)
            pl.hist(result[0].data, weights=result[1].data, bins=50, log=False)
            pl.show()
        elif mode == "Gyrophase_angle":
            # Plot gyrophase angle distribution:
            from gyrophaseangle import gyrophase_angles_from_file
            result = gyrophase_angles_from_file(vlsvReader=self.__vlsvReader,
                                                cellid=cellid)
            pl.hist(result[0].data,
                    weights=result[1].data,
                    bins=36,
                    range=[-180.0, 180.0],
                    log=True,
                    normed=1)
            pl.show()
        elif mode == "Cut_through":
            self.__cut_through_pick(coordinates)

    def __nearest_cell_with_blocks(self, cellid):
        """ Returns the id of the cell with velocity blocks nearest to cellid

          :param cellid:          Cell id of the picked cell
          :returns: The id, among the cells listed under CELLSWITHBLOCKS,
                    whose coordinates are closest to those of ``cellid``
                    (the first one on ties)
      """
        reader = self.__vlsvReader
        # Read cell ids with velocity distribution in:
        cell_candidates = reader.read("SpatialGrid", "CELLSWITHBLOCKS")
        candidate_coordinates = np.array([
            reader.get_cell_coordinates(candidate)
            for candidate in cell_candidates
        ])
        pick_cell_coordinates = reader.get_cell_coordinates(cellid)
        distances = np.sum(
            (candidate_coordinates - pick_cell_coordinates)**2,
            axis=-1)**(1. / 2)
        return cell_candidates[np.argmin(distances)]

    def __cut_through_pick(self, coordinates):
        """ Handles one click of the Cut_through tool

          The first click only stores its position.  The second click
          computes the cut-through between both positions, stores it in
          ``self.cut_through`` and processes the whitespace-separated
          tokens of ``self.args``:

          * ``plot`` plots the collected variables against distance
          * ``rankine`` calls plot_rankine between the two points, stores
            the returned figure in ``self.plot`` and ends the handling
          * ``name`` or ``name,operator`` reads that variable along the
            cut-through and appends it to ``self.cut_through``

          :param coordinates:     Position of the current click
      """
        if len(self.__last_pick) != 3:
            # First click: remember it and wait for the second one
            self.__last_pick = coordinates
            return

        from cutthrough import cut_through
        self.cut_through = cut_through(self.__vlsvReader,
                                       point1=self.__last_pick,
                                       point2=coordinates)
        # Get cell ids and distances separately
        cellids = self.cut_through[0].data
        distances = self.cut_through[1]

        args = self.args.split()
        if len(args) == 0:
            print("Bad args")
            self.__last_pick = []
            return

        plot_cut = False
        variables = []
        # Optimize file read; the finally clause guarantees the matching
        # close on every exit path, including the "rankine" early return
        # and read errors.
        self.__vlsvReader.optimize_open_file()
        try:
            for arg in args:
                if arg == "plot":
                    plot_cut = True
                elif arg == "rankine":
                    # set labels:
                    self.__add_normal_labels(point1=self.__last_pick,
                                             point2=coordinates)
                    fig = plot_rankine(self.__vlsvReader,
                                       point1=self.__last_pick,
                                       point2=coordinates)
                    self.__last_pick = []
                    self.plot = fig
                    return
                else:
                    if "," in arg:
                        pieces = arg.split(',')
                        variable_info = self.__vlsvReader.read_variable_info(
                            name=pieces[0],
                            cellids=cellids,
                            operator=pieces[1])
                    else:
                        variable_info = self.__vlsvReader.read_variable_info(
                            name=arg, cellids=cellids)
                    variables.append(variable_info)
                    self.cut_through.append(variable_info)
            if plot_cut:
                # Label both ends of the cut-through:
                self.__add_label(cellids[0])
                self.__add_label(cellids[-1])
                from plot import plot_multiple_variables
                # One distance axis per plotted variable
                plot_multiple_variables([distances for _ in variables],
                                        variables,
                                        figure=[])
                pl.show()
        finally:
            # Close the optimized file read:
            self.__vlsvReader.optimize_close_file()
        self.__last_pick = []

    def __generate_grid(self, mins, maxs, cells, datas, names):
        ''' Builds a structured grid from the given data and shows it

          :param mins:           An array of minimum coordinates for the grid
          :param maxs:           An array of maximum coordinates for the grid
          :param cells:          An array of number of cells in x, y, z direction
          :param datas:          Scalar data for the grid e.g. array([ cell1Rho, cell2Rho, cell3Rho, cell4Rho, .., cellNRho ])
          :param names:          Name for the scalar data
      '''
        # Node coordinates: cells + 1 nodes along every axis.  The
        # imaginary step is how the node count is passed to mgrid
        # (assuming mgrid is numpy.mgrid, where it means "number of
        # points, end point included").
        node_counts = [(cells[axis] + 1) * 1j for axis in range(3)]
        x, y, z = mgrid[mins[0]:maxs[0]:node_counts[0],
                        mins[1]:maxs[1]:node_counts[1],
                        mins[2]:maxs[2]:node_counts[2]]

        # Pack the node coordinates into one (nx, ny, nz, 3) array:
        nodes = empty(z.shape + (3, ), dtype=float)
        for component, coordinate in enumerate((x, y, z)):
            nodes[..., component] = coordinate

        # We reorder the points and scalars so this is as per VTK's
        # requirement of x first, y next and z last.
        nodes = nodes.transpose(2, 1, 0, 3).copy()
        nodes.shape = nodes.size // 3, 3
        cell_values = np.array(datas).T.copy()

        # Create the dataset.
        grid = tvtk.StructuredGrid(dimensions=x.shape, points=nodes)
        grid.cell_data.scalars = ravel(cell_values.copy())
        grid.cell_data.scalars.name = names

        # Visualize the data
        source = self.scene.mlab.pipeline.add_dataset(grid)
        self.scene.mlab.pipeline.surface(source)

        self.dataset = source

        # Open the window.  Running configure_traits in a separate thread
        # from here was tried; it seemed to work at first but eventually
        # caused segmentation faults, so it is called directly.
        self.configure_traits()

    def __generate_velocity_grid(self, cellid, iso_surface=False):
        '''Plots the velocity space of a spatial cell in a new scene

         The scene also contains the normalised magnetic field direction
         of the cell drawn at the origin, plus axes and an outline.

         :param cellid:           The spatial cell's ID
         :param iso_surface:      If true, plots the iso surface
         :returns: False if the cell has no velocity blocks, True otherwise
      '''
        reader = self.__vlsvReader

        # Get velocity blocks and avgs:
        blocksAndAvgs = reader.read_blocks(cellid)
        if len(blocksAndAvgs) == 0:
            print("CELL " + str(cellid) + " HAS NO VELOCITY BLOCK")
            return False

        # Create a new scene and make it the current figure; rendering is
        # switched off until the scene is fully populated.
        self.__engine.new_scene()
        mayavi.mlab.set_engine(self.__engine)
        figure = mayavi.mlab.gcf(engine=self.__engine)
        figure.scene.disable_render = True

        # Create an unstructured grid of voxels from the velocity blocks:
        nodesAndKeys = reader.construct_velocity_cell_nodes(blocksAndAvgs[0])
        voxel_type = tvtk.Voxel().cell_type  #VTK_VOXEL
        velocity_grid = tvtk.UnstructuredGrid(points=nodesAndKeys[0])
        velocity_grid.set_cells(voxel_type, nodesAndKeys[1])
        velocity_grid.cell_data.scalars = np.ravel(blocksAndAvgs[1])
        velocity_grid.cell_data.scalars.name = 'avgs'

        # Read the B vector: prefer "B", then "B_vol", and otherwise sum
        # the background and perturbed fields.
        for field_name in ("B", "B_vol"):
            if reader.check_variable(field_name) == True:
                B = reader.read_variable(name=field_name, cellids=cellid)
                break
        else:
            B = reader.read_variable(
                name="background_B",
                cellids=cellid) + reader.read_variable(name="perturbed_B",
                                                       cellids=cellid)

        # A one-point dataset at the origin carrying the unit B vector:
        field_grid = tvtk.UnstructuredGrid(points=np.array([[0, 0, 0]]))
        field_grid.point_data.vectors = [B / np.linalg.norm(B)]
        field_grid.point_data.vectors.name = 'B_vector'
        field_source = mayavi.mlab.pipeline.add_dataset(field_grid)
        arrow = mayavi.mlab.pipeline.vectors(field_source)
        arrow.glyph.mask_input_points = True
        arrow.glyph.glyph.scale_factor = 1e6
        arrow.glyph.glyph_source.glyph_source.center = [0, 0, 0]

        # Visualize
        velocity_source = mayavi.mlab.pipeline.add_dataset(velocity_grid)
        if iso_surface == False:
            mayavi.mlab.pipeline.surface(velocity_source)
        else:
            point_data = mayavi.mlab.pipeline.cell_to_point_data(
                velocity_source)
            mayavi.mlab.pipeline.iso_surface(
                point_data, contours=[1e-15, 1e-14, 1e-12], opacity=0.3)
        figure.scene.disable_render = False
        self.__unstructured_figures.append(figure)

        # Name the figure after the cell and its value of the plotted
        # variable:
        plotted_value = reader.read_variable(self.variable_plotted,
                                             cellids=cellid)
        figure.name = (str(cellid) + ", " + self.variable_plotted + " = " +
                       str(plotted_value))

        from mayavi.modules.axes import Axes
        axes = Axes()
        axes.name = 'Axes'
        axes.axes.fly_mode = 'none'
        axes.axes.number_of_labels = 8
        axes.axes.font_factor = 0.5
        self.__engine.add_filter(axes)

        from mayavi.modules.outline import Outline
        outline = Outline()
        outline.name = 'Outline'
        self.__engine.add_filter(outline)
        return True

    def generate_diff_grid(self, cellid1, cellid2):
        """Plot the velocity-space avgs difference between two cells.

        Blocks found in both cells hold ``avgs(cellid1) - avgs(cellid2)``;
        blocks found in only one cell hold that cell's avgs unchanged.  The
        merged blocks are drawn as a surface in a new mayavi scene and
        ``mayavi.mlab.show()`` is called before returning.

        :param cellid1:          The first cell id
        :param cellid2:          The second cell id
        :returns:                False if ``read_blocks`` yields nothing for
                                 either cell, True otherwise

        .. code-block:: python

           # Example:
           grid.generate_diff_grid( 29219, 2910 )

        .. note:: If the cell id does not have a certain velocity cell, it is assumed that the avgs value of that cell is 0
        """
        # Velocity blocks and avgs of the first cell
        data1 = self.__vlsvReader.read_blocks(cellid1)
        if len(data1) == 0:
            print("CELL " + str(cellid1) + " HAS NO VELOCITY BLOCK")
            return False
        blocks1, avgs1 = data1[0], data1[1]

        # Velocity blocks and avgs of the second cell
        data2 = self.__vlsvReader.read_blocks(cellid2)
        if len(data2) == 0:
            print("CELL " + str(cellid2) + " HAS NO VELOCITY BLOCK")
            return False
        blocks2, avgs2 = data2[0], data2[1]
        print(len(avgs2))
        print(len(blocks2))

        # Sort every block into: shared by both cells, only in cell 1,
        # only in cell 2.
        shared_avgs, shared_blocks = [], []
        only1_avgs, only1_blocks = [], []
        only2_avgs, only2_blocks = [], []
        print(np.shape(avgs1[0]))
        for row, block in enumerate(blocks1):
            # Positions of this block id in the second cell (if any)
            hits = np.where(blocks2 == block)[0]
            if len(hits) == 0:
                only1_avgs.append(avgs1[row:row + 1])
                only1_blocks.append(block)
                continue
            other = hits[0]
            shared_avgs.append(avgs1[row:row + 1] - avgs2[other:other + 1])
            shared_blocks.append(block)
        for row, block in enumerate(blocks2):
            if block not in blocks1:
                # NOTE(review): stored with its original sign, i.e. not
                # negated like a true "cell1 - cell2" diff would be --
                # confirm this is intended.
                only2_avgs.append(avgs2[row:row + 1])
                only2_blocks.append(block)

        # Flat output arrays; 64 avgs entries are reserved per block.
        groups = ((shared_avgs, shared_blocks),
                  (only1_avgs, only1_blocks),
                  (only2_avgs, only2_blocks))
        avgs = np.zeros(
            64 * (len(shared_avgs) + len(only1_avgs) + len(only2_avgs)))
        print(np.shape(shared_avgs))
        blocks = np.zeros(
            len(shared_blocks) + len(only1_blocks) + len(only2_blocks))

        # Copy the three groups back to back: shared, cell 1, cell 2.
        offset = 0
        for group_avgs, group_blocks in groups:
            count = len(group_blocks)
            avgs[64 * offset:64 * (offset + count)] = np.ravel(
                np.array(group_avgs))
            blocks[offset:offset + count] = np.array(group_blocks)
            offset = offset + count

        blocks = blocks.astype(int)

        # Node coordinates and cell connectivity for the merged blocks
        nodes_and_keys = self.__vlsvReader.construct_velocity_cell_nodes(
            blocks)
        points = nodes_and_keys[0]
        cells = nodes_and_keys[1]

        # New scene / figure; rendering stays off while the grid is built
        self.__engine.new_scene()
        mayavi.mlab.set_engine(self.__engine)
        figure = mayavi.mlab.gcf(engine=self.__engine)
        figure.scene.disable_render = True
        voxel_type = tvtk.Voxel().cell_type  # VTK_VOXEL

        # Unstructured grid with the avgs attached as cell scalars
        ug = tvtk.UnstructuredGrid(points=points)
        ug.set_cells(voxel_type, cells)
        ug.cell_data.scalars = np.ravel(avgs)
        ug.cell_data.scalars.name = 'avgs'
        dataset = mayavi.mlab.pipeline.add_dataset(ug)
        mayavi.mlab.pipeline.surface(dataset)
        figure.scene.disable_render = False
        self.__unstructured_figures.append(figure)
        # Name the figure after the two compared cells
        figure.name = str(cellid1) + " " + str(cellid2)
        mayavi.mlab.show()
        return True

    def __do_nothing(self, picker):
        """No-op pick callback: ignores *picker* and returns None."""
        return None

    # Trait events:
    @on_trait_change('scene.activated')
    def set_mouse_click(self):
        """Hook up the cell picker and show the scalar legend.

        Runs when the ``scene.activated`` trait fires: grabs the current
        figure, registers the pick callback on it, remembers the picker
        settings and switches on the scalar bar / legend.
        """
        # NOTE: a dummy pick (see __do_nothing) was once registered here as a
        # MayaVi workaround for removing picker callbacks; it is disabled.
        self.figure = self.scene.mlab.gcf()

        # Cell picker: register the callback and remember its settings
        callback = self.__picker_callback
        pick_type = 'world'
        button = 'Left'
        self.figure.on_mouse_pick(callback, type=pick_type)
        self.__picker = [callback, pick_type, button]

        # Show legend bar
        lut_manager = self.figure.children[0].children[0].scalar_lut_manager
        lut_manager.show_scalar_bar = True
        lut_manager.show_legend = True
class ParameterTemplate(BaseTemplate):
    """BaseTemplate subclass to generate MCO Parameter options for
    SurfactantContributedUI.

    The ``parameter_type`` selection decides which of the value traits
    below are shown in the view and written out by ``create_template``.
    """

    # --------------------
    #  Regular Attributes
    # --------------------

    #: String representing MCOParameter subclass
    parameter_type = Enum('Fixed', 'Ranged', 'Listed')

    #: Name of Parameter
    name = Unicode()

    #: CUBA type of Parameter
    type = Unicode('CONCENTRATION')

    #: MCOParameter level trait
    value = Float(1.0)

    #: RangedMCOParameter lower_bound trait
    lower_bound = Float(0.5)

    #: RangedMCOParameter upper_bound trait
    upper_bound = Float(5.0)

    #: RangedMCOParameter n_samples trait
    n_samples = Int(10)

    #: ListedMCOParameter levels trait
    levels = ListFloat([0.5, 1.0, 3.0])

    # --------------------
    #      Properties
    # --------------------

    #: Factory ID for Workflow. The getter reads ``plugin_id`` and
    #: ``parameter_type``, so those are the traits whose changes must
    #: trigger an ``id`` change notification.
    id = Property(Unicode, depends_on='plugin_id,parameter_type')

    # --------------------
    #        View
    # --------------------

    # Only the items matching the selected parameter_type are visible.
    traits_view = View(
        Item('parameter_type'),
        Item("value", visible_when="parameter_type=='Fixed'"),
        Item("lower_bound", visible_when="parameter_type=='Ranged'"),
        Item("upper_bound", visible_when="parameter_type=='Ranged'"),
        Item("n_samples", visible_when="parameter_type=='Ranged'"),
        Item("levels",
             editor=ListEditor(style='simple'),
             visible_when="parameter_type=='Listed'"))

    # --------------------
    #      Listeners
    # --------------------

    def _get_id(self):
        """Return ``<plugin_id>.parameter.<parameter_type in lower case>``.

        ``plugin_id`` is not declared in this class; presumably it comes
        from BaseTemplate -- TODO confirm.
        """
        return '.'.join(
            [self.plugin_id, 'parameter',
             self.parameter_type.lower()])

    # --------------------
    #    Public Methods
    # --------------------

    def create_template(self):
        """Build the dictionary describing this parameter.

        Returns
        -------
        dict
            ``{"id": ..., "model_data": {...}}`` where ``model_data`` always
            holds ``name`` (the parameter name with a ``_conc`` suffix) and
            ``type``, plus the entries specific to the selected
            ``parameter_type``: ``value`` for 'Fixed'; ``lower_bound``,
            ``upper_bound`` and ``n_samples`` for 'Ranged'; ``levels`` for
            'Listed'.
        """
        model_data = {
            "name": f"{self.name}_conc",
            "type": self.type
        }
        template = {
            "id": self.id,
            "model_data": model_data
        }

        if self.parameter_type == 'Fixed':
            model_data["value"] = self.value
        elif self.parameter_type == 'Ranged':
            model_data["lower_bound"] = self.lower_bound
            model_data["upper_bound"] = self.upper_bound
            model_data["n_samples"] = self.n_samples
        elif self.parameter_type == 'Listed':
            model_data["levels"] = self.levels

        return template
Esempio n. 29
0
class BasePlotFrame(Container, PlotComponent):
    """
    Base class for plot frames.  Primarily defines the basic functionality
    of managing slots (sub-containers) within the plot frame.

    NOTE: PlotFrames are deprecated. There is no need to use them any more.
    This class will be removed some time in the near future.
    """

    #: A named list of places/positions/"slots" on the frame where PlotComponents
    #: can place themselves.  Subclasses must redefine this trait with the
    #: appropriate values.  Note that by default, __getattr__ treats these
    #: slot names as attributes on the class so they can be directly accessed.
    #: This is a class attribute.
    slot_names = ()

    #: Dimensions in which this frame can resize to fit its components.
    #: This is similar to the **resizable** trait on PlotComponent. Chaco
    #: plot frames use this attribute in preference to the Enable
    #: **auto_size** attribute (which is overridden to be False by default).
    fit_components = Enum("", "h", "v", "hv")

    #: Overrides the Enable auto_size trait (which will be deprecated in the future)
    auto_size = False

    draw_order = DEFAULT_DRAWING_ORDER

    def __init__(self, **kw):
        # Maps slot name -> container currently occupying that slot.  Must
        # exist before the parent constructors run, because __setattr__
        # routes slot-name assignments through set_slot().
        self._frame_slots = {}
        super(BasePlotFrame, self).__init__(**kw)
        return

    def add_to_slot(self, slot, component, stack="overlay"):
        """
        Adds a component to the named slot using the given stacking mode.
        The valid modes are: 'overlay', 'left', 'right', 'top', 'bottom'.

        Raises KeyError if no container has been set for *slot*.
        """
        # The slot containers live in the private ``_frame_slots`` dict;
        # there is no ``frame_slots`` attribute.
        self._frame_slots[slot].add_plot_component(component, stack)
        return

    def set_slot(self, slotname, container):
        """
        Sets the named slot to use the given container. *container* can be None.

        Any container previously held by the slot is removed from this
        frame.  Passing None leaves the slot empty, so that get_slot()
        returns None for it afterwards.
        """
        # Drop the old entry as well as the component itself; otherwise a
        # cleared slot would keep handing out a container that is no longer
        # part of this frame.
        old_container = self._frame_slots.pop(slotname, None)
        if old_container is not None:
            Container.remove(self, old_container)
        if container is not None:
            self._frame_slots[slotname] = container
            Container.add(self, container)
        return

    def get_slot(self, slotname):
        """ Returns the container in the named slot, or None if it is empty. """
        return self._frame_slots.get(slotname, None)

    #------------------------------------------------------------------------
    # PlotComponent interface
    #------------------------------------------------------------------------

    def draw(self, gc, view_bounds=None, mode="normal"):
        """ Draws the plot frame.

        Frames are the topmost Chaco component that knows about layout, and they
        are the start of the layout pipeline.  When they are asked to draw,
        they can assume that their own size has been set properly and this in
        turn drives the layout of the contained components within the frame.
        """
        self.do_layout()
        super(BasePlotFrame, self).draw(gc, view_bounds, mode)
        return

    def do_layout(self, size=None, force=False):
        """ Tells this frame to do layout at a given size.

        Overrides PlotComponent. If this frame needs to fit components in at
        least one dimension, then it checks whether any of them need to do
        layout; if so, the frame needs to do layout also.
        """
        if not self._layout_needed and not force and self.fit_components != "":
            if any(slot._layout_needed
                   for slot in self._frame_slots.values()):
                self._layout_needed = True
        return PlotComponent.do_layout(self, size, force)

    def _draw(self, *args, **kw):
        """ Draws the plot frame.

        Overrides PlotComponent and Container, explicitly calling the
        PlotComponent version of _draw().
        """
        PlotComponent._draw(self, *args, **kw)
        return

    def _dispatch_to_enable(self, event, suffix):
        """ Calls Enable-level event handlers.

        Overrides PlotComponent.
        """
        Container.dispatch(self, event, suffix)
        return

    #------------------------------------------------------------------------
    # Event handlers, properties
    #------------------------------------------------------------------------

    def _bounds_changed(self, old, new):
        # Tell the parent container (if any) and schedule a new layout.
        if self.container is not None:
            self.container._component_bounds_changed(self)
        self._layout_needed = True
        return

    def _bounds_items_changed(self, event):
        # An in-place edit of bounds is handled like a full replacement.
        return self._bounds_changed(None, self.bounds)

    #------------------------------------------------------------------------
    # Private methods
    #------------------------------------------------------------------------

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails: expose the slot
        # names as attributes.  An empty slot raises KeyError.
        if name in self.slot_names:
            return self._frame_slots[name]
        else:
            raise AttributeError("'%s' object has no attribute '%s'" % \
                                    (self.__class__.__name__, name))

    def __setattr__(self, name, value):
        # Assigning to a slot name installs the value as that slot's
        # container; everything else is a regular attribute.
        if name in self.slot_names:
            self.set_slot(name, value)
        else:
            super(BasePlotFrame, self).__setattr__(name, value)
        return

    ### Persistence ###########################################################
#    _pickles = ("_frame_slots", "_components", "fit_components", "fit_window")

    def post_load(self, path=None):
        """ Runs the parent post_load(), then post_load() on every slot. """
        super(BasePlotFrame, self).post_load(path)
        for slot in self._frame_slots.values():
            slot.post_load(path)
        return
class ImplicitWidgets(Component):
    """Component that offers a choice of 3D widgets (box, sphere, plane,
    implicit plane) and keeps a matching tvtk implicit function updated
    from the selected widget.
    """

    # The version of this class.  Used for persistence.
    __version__ = 0

    # The widget type to use.
    widget_mode = Enum('Box', 'Sphere', 'Plane','ImplicitPlane',
                       desc='the implicit widget to use')

    # The actual poly data source widget.
    widget = Instance(tvtk.ThreeDWidget, record=True)

    # Which widget event triggers an update of the implicit function.
    update_mode = Trait('semi-interactive',
                        TraitMap({'interactive':'InteractionEvent',
                                  'semi-interactive': 'EndInteractionEvent'}),
                        desc='speed at which the data should be updated')

    # The implicit function matching the current widget mode.
    implicit_function = Instance(tvtk.ImplicitFunction, allow_none=False)

    ########################################
    # Private traits.

    # True until the widget has been placed once in update_pipeline.
    _first = Bool(True)
    # Re-entrancy guard for _on_widget_trait_changed.
    _busy = Bool(False)
    # Tag returned by add_observer for the interaction observer.
    _observer_id = Int(-1)

    # The actual widgets.
    _widget_dict = Dict(Str, Instance(tvtk.ThreeDWidget,
                        allow_none=False))

    # The actual implicit functions.
    _implicit_function_dict = Dict(Str, Instance(tvtk.ImplicitFunction,
                                   allow_none=False))

    ########################################
    # View related traits.
    ########################################
     # Create the UI for the traits.
    view = View(Group(Item(name='widget_mode'), Item(name='widget',
                            style='custom', resizable=True),
                            label='Widget Source', show_labels=False),
                            resizable=True)

    #####################################################################
    # `object` interface
    ######################################################################
    def __init__(self, **traits):
        # Call parent class' init.
        super(ImplicitWidgets, self).__init__(**traits)

        # Initialize the source to the default widget's instance from
        # the dictionary if needed.
        if 'widget_mode' not in traits:
            self._widget_mode_changed(self.widget_mode)

    ######################################################################
    # `Base` interface
    ######################################################################
    def __get_pure_state__(self):
        """Return the state to persist, without the transient traits and
        with the box widget's transformation matrix added as a pickle.
        """
        d = super(ImplicitWidgets, self).__get_pure_state__()
        for attr in ('_first', '_busy', '_observer_id', 'widget',
                     'implicit_function'):
            d.pop(attr, None)
        # The box widget requires a transformation matrix to be pickled.
        tfm = tvtk.Transform()
        w = self._widget_dict['Box']
        w.get_transform(tfm)
        d['matrix'] = pickle.dumps(tfm.matrix)
        return d

    def __set_pure_state__(self, state):
        """Restore the state produced by `__get_pure_state__`."""
        # Pop the transformation matrix for the box widget.
        mat = state.pop('matrix')
        # Now set their state.
        set_state(self, state, first=['widget_mode'], ignore=['*'])
        # Set state of rest of the attributes ignoring the widget_mode.
        set_state(self, state, ignore=['widget_mode'])

        # Set the transformation for Box widget.
        # NOTE(review): pickle.loads can execute arbitrary code on crafted
        # input; only restore state that comes from a trusted source.
        tfm = tvtk.Transform()
        tfm.set_matrix(pickle.loads(mat))
        w = self._widget_dict['Box']
        w.set_transform(tfm)

        # Some widgets need some cajoling to get their setup right.
        w = self.widget
        # Set the input.
        if len(self.inputs) > 0:
            self.configure_input(w, self.inputs[0].outputs[0])
        w.update_traits()
        mode = self.widget_mode
        if mode == 'Plane':
            wd = state._widget_dict[mode]
            w.origin = wd.origin
            w.normal = wd.normal
            w.update_placement()
        self.update_implicit_function()
        # Set the widgets trait so that the widget is rendered if needed.
        self.widgets = [w]

    ######################################################################
    # `Component` interface
    ######################################################################
    def setup_pipeline(self):
        """Override this method so that it *creates* the tvtk
        pipeline.
        """
        # Setup the widgets.
        self.widgets = [self.widget]

    def update_pipeline(self):
        """Override this method so that it *updates* the tvtk pipeline
        when data upstream is known to have changed.

        This method is invoked (automatically) when any of the inputs
        sends a `pipeline_changed` event.
        """
        if len(self.inputs) == 0:
            return
        inp = self.inputs[0].outputs[0]
        w = self.widget
        self.configure_input(w, inp)

        # Place the widget only the first time an input is available.
        if self._first:
            w.place_widget()
            self._first = False

        # Set our output.
        if self.outputs != [inp]:
            self.outputs = [inp]
        else:
            self.data_changed = True

        self.pipeline_changed = True

    def update_data(self):
        """Override this method so that it flushes the vtk pipeline if
        that is necessary.

        This method is invoked (automatically) when any of the inputs
        sends a `data_changed` event.
        """
        self.data_changed = True

    ######################################################################
    # `SourceWidget` interface
    ######################################################################
    def update_implicit_function(self):
        """Update the implicit_function from the widget data.
        """
        # Name of the widget method that fills in the implicit function,
        # keyed by widget mode.
        dispatch = {'Sphere': 'get_sphere', 'Box': 'get_planes',
                    'Plane': 'get_plane', 'ImplicitPlane': 'get_plane'}
        method = getattr(self.widget, dispatch[self.widget_mode])
        method(self.implicit_function)

    ######################################################################
    # Non-public traits.
    ######################################################################
    def _widget_changed(self, old, value):
        if len(self.inputs) > 0:
            self.configure_input(value, self.inputs[0].outputs[0])
            value.place_widget()
        self.implicit_function = self._implicit_function_dict[self.widget_mode]

        # Move the event wiring from the old widget to the new one.
        if old is not None:
            self._connect(old, remove=True)
        self._connect(value, remove=False)
        self.widgets = [value]

    def _connect(self, value, remove=False):
        """Wire up event handlers or tear them down given a widget
        `value`.  If `remove` is True, then tear them down."""
        if remove:
            # Only remove an observer that was actually registered, and
            # never register a new one on a widget being disconnected.
            if self._observer_id > 0:
                value.remove_observer(self._observer_id)
        else:
            self._observer_id = value.add_observer(self.update_mode_,
                                                   self._on_interaction_event)
        if isinstance(value, tvtk.PlaneWidget) or \
            isinstance(value, tvtk.ImplicitPlaneWidget):
            value.on_trait_change(self._on_alignment_set,
                                  'normal_to_x_axis', remove=remove)
            value.on_trait_change(self._on_alignment_set,
                                  'normal_to_y_axis', remove=remove)
            value.on_trait_change(self._on_alignment_set,
                                  'normal_to_z_axis', remove=remove)

        value.on_trait_change(self._on_widget_trait_changed,
                              remove=remove)
        value.on_trait_change(self.render, remove=remove)

    def _on_interaction_event(self, obj, event):
        self.update_implicit_function()

    def _update_mode_changed(self, old, new):
        # Re-register the interaction observer for the newly mapped event.
        w = self.widget
        if w is not None:
            w.remove_observer(self._observer_id)
            self._observer_id = w.add_observer(self.update_mode_,
                    self._on_interaction_event)

            w.on_trait_change(self.render)
            self.render()

    def _on_widget_trait_changed(self):
        if (not self._busy) and (self.update_mode != 'non-interactive'):
            self._busy = True
            try:
                self.implicit_function = \
                    self._implicit_function_dict[self.widget_mode]
                self.update_implicit_function()
                self.render()
            finally:
                # Always release the guard; otherwise one failed update
                # would silence every later widget change.
                self._busy = False

    def _on_alignment_set(self):
        """Event handler when the widget's normal is reset (if
        applicable)."""
        w = self.widget
        w.place_widget()
        w.update_traits()
        self.render()

    def _scene_changed(self, old, new):
        super(ImplicitWidgets, self)._scene_changed(old, new)
        self._foreground_changed_for_scene(None, new.foreground)

    def _widget_mode_changed(self, value):
        """This method is invoked (automatically) when the `widget_mode`
        trait is changed.
        """
        self.widget = self._widget_dict[self.widget_mode]

    def __widget_dict_default(self):
        """Default value for the widget dict."""
        w = {'Box':tvtk.BoxWidget(place_factor = 0.9),
             'Sphere':tvtk.SphereWidget(place_factor = 0.9),
             'Plane':tvtk.PlaneWidget(place_factor = 0.9),
             'ImplicitPlane':
                tvtk.ImplicitPlaneWidget(place_factor=0.9,
                                         draw_plane=False)}
        return w

    def __implicit_function_dict_default(self):
        """Default value for the implicit function dict."""
        ip = {'Box':tvtk.Planes(),
              'Sphere':tvtk.Sphere(),
              'Plane':tvtk.Plane(),
              'ImplicitPlane': tvtk.Plane()}
        return ip
Esempio n. 31
0
class Enum(Variable):
    """A variable wrapper for an enumeration, which is a variable that
       can assume one value from a set of specified values.
       """

    def __init__(self, default_value=None, values=(), iotype=None,
                 aliases=(), desc=None, **metadata):
        """Set up the enumeration.

        default_value: object
          Initial value. If `values` is empty, this argument supplies the
          allowed values instead (a single value or a tuple/list whose
          first item becomes the default).

        values: tuple or list
          The allowed values. If `default_value` is None, the first one is
          used as the default.

        iotype: str
          Stored in the metadata when not None.

        aliases: tuple or list
          Optional aliases; must have the same length as `values`.

        desc: str
          Stored in the metadata when not None.

        Raises ValueError if there are no values, if the number of aliases
        differs from the number of values, or if the default is not one of
        the values.
        """

        assumed_default = False

        # Allow some variant constructors (no default, no index)
        if not values:
            if default_value is None:
                raise ValueError("Enum must contain at least one value.")

            values = default_value
            if isinstance(values, (tuple, list)):
                # An empty sequence given as the only argument holds no
                # values either; report it like the no-argument case
                # instead of failing on values[0].
                if not values:
                    raise ValueError("Enum must contain at least one value.")
                default_value = values[0]
        elif default_value is None:
            default_value = values[0]
            assumed_default = True

        # We need tuples or a list for the index
        if not isinstance(values, (tuple, list)):
            values = (values,)

        if aliases:
            if not isinstance(aliases, (tuple, list)):
                aliases = (aliases,)

            if len(aliases) != len(values):
                raise ValueError("Length of aliases does not match "
                                 "length of values.")

        if default_value not in values:
            raise ValueError("Default value not in values.")

        self._validator = TraitEnum(default_value, values, **metadata)

        # Put iotype in the metadata dictionary
        if iotype is not None:
            metadata['iotype'] = iotype

        # Put desc in the metadata dictionary
        if desc is not None:
            metadata['desc'] = desc

        # Put values in the metadata dictionary
        if values:
            metadata['values'] = values

            # We also need to store the values in a dict, to get around
            # a weak typechecking (i.e., enum of [1,2,3] can be 1.0)
            self.valuedict = {}

            for val in values:
                self.valuedict[val] = val

        # Put aliases in the metadata dictionary
        if aliases:
            metadata['aliases'] = aliases

        # assumed_default is passed explicitly below, so a copy supplied by
        # the caller in the metadata must not be forwarded as well.
        metadata.pop('assumed_default', None)

        super(Enum, self).__init__(default_value=default_value,
                                   assumed_default=assumed_default, **metadata)

    def get_attribute(self, name, value, trait, meta):
        """Return the attribute dictionary for this variable. This dict is
        used by the GUI to populate the edit UI.

        name: str
          Name of variable

        value: object
          The value of the variable

        trait: CTrait
          The variable's trait

        meta: dict
          Dictionary of metadata for this variable

        Returns a tuple of the attribute dictionary and None.
        """

        attr = {}

        attr['name'] = name
        attr['type'] = "enum"
        attr['value'] = value

        # Copy every metadata field that is not excluded from the GUI.
        for field in meta:
            if field not in gui_excludes:
                attr[field] = meta[field]

        attr['value_types'] = [type(val).__name__ for val in meta['values']]

        return attr, None

    def validate(self, obj, name, value):
        """ Validates that a specified value is valid for this trait."""

        try:
            val = self._validator.validate(obj, name, value)
        except Exception:
            # error() always raises, so `val` is bound whenever the
            # return statement below is reached.
            self.error(obj, name, value)

        # if someone uses a float to set an int-valued Enum, we want it to
        # be an int. Enthought's Enum allows a float value, unfortunately.
        return self.valuedict[val]

    def error(self, obj, name, value):
        """Raises a ValueError describing why `value` was rejected.

        The exception is raised through `obj.raise_exception` when `obj`
        provides it, and directly otherwise.
        """

        # pylint: disable=E1101
        vtype = type(value)
        if value not in self.values:
            info = str(self.values)
            msg = "Variable '%s' must be in %s, " % (name, info) + \
                "but a value of %s %s was specified." % (value, vtype)
        else:
            msg = "Unknown error while setting trait '%s';" % (name) +\
                  "a value of %s %s was specified." % (value, vtype)

        try:
            obj.raise_exception(msg, ValueError)
        except AttributeError:
            raise ValueError(msg)
Esempio n. 32
0
class ToolkitEditorFactory(EditorFactory):
    """ Editor factory for tree editors.
    """
    #-------------------------------------------------------------------------
    #  Trait definitions:
    #-------------------------------------------------------------------------

    #: Supported TreeNode objects
    nodes = List(TreeNode)

    #: Mapping from TreeNode tuples to MultiTreeNodes
    multi_nodes = Dict

    #: The column header labels if any.
    column_headers = List(Str)

    #: Are the individual nodes editable?
    editable = Bool(True)

    #: Selection mode.
    selection_mode = Enum('single', 'extended')

    #: Is the editor shared across trees?
    shared_editor = Bool(False)

    #: Reference to a shared object editor
    editor = Instance(EditorFactory)

    # FIXME: Implemented only in wx backend.
    #: The DockWindow graphical theme
    dock_theme = Instance(DockWindowTheme)

    #: Show icons for tree nodes?
    show_icons = Bool(True)

    #: Hide the tree root node?
    hide_root = Bool(False)

    #: Layout orientation of the tree and the editor
    orientation = Orientation

    #: Number of tree levels (down from the root) that should be automatically
    #: opened
    auto_open = Int

    #: Size of the tree node icons
    icon_size = IconSize

    #: Called when a node is selected
    on_select = Any

    #: Called when a node is clicked
    on_click = Any

    #: Called when a node is double-clicked
    on_dclick = Any

    #: Called when a node is activated
    on_activated = Any

    #: Call when the mouse hovers over a node
    on_hover = Any

    #: The optional extended trait name of the trait to synchronize with the
    #: editor's current selection:
    selected = Str

    #: The optional extended trait name of the trait that should be assigned
    #: a node object when a tree node is activated, by double-clicking or
    #: pressing the Enter key when a node has focus (Note: if you want to
    #: receive repeated activated events on the same node, make sure the trait
    #: is defined as an Event):
    activated = Str

    #: The optional extended trait name of the trait that should be assigned
    #: a node object when a tree node is clicked on (Note: If you want to
    #: receive repeated clicks on the same node, make sure the trait is defined
    #: as an Event):
    click = Str

    #: The optional extended trait name of the trait that should be assigned
    #: a node object when a tree node is double-clicked on (Note: if you want to
    #: receive repeated double-clicks on the same node, make sure the trait is
    #: defined as an Event):
    dclick = Str

    #: The optional extended trait name of the trait event that is fired
    #: whenever the application wishes to veto a tree action in progress (e.g.
    #: double-clicking a non-leaf tree node normally opens or closes the node,
    #: but if you are handling the double-click event in your program, you may
    #: wish to veto the open or close operation). Be sure to fire the veto event
    #: in the event handler triggered by the operation (e.g. the 'dclick' event
    #: handler.
    veto = Str

    #: The optional extended trait name of the trait event that is fired when the
    #: application wishes the currently visible portion of the tree widget to
    #: repaint itself.
    refresh = Str

    #: Mode for lines connecting tree nodes
    #:
    #: * 'appearance': Show lines only when they look good.
    #: * 'on': Always show lines.
    #: * 'off': Don't show lines.
    lines_mode = Enum('appearance', 'on', 'off')

    # FIXME: Document as unimplemented or wx specific.
    #: Whether to alternate row colors or not.
    alternating_row_colors = Bool(False)

    #: Any extra vertical padding to add.
    vertical_padding = Int(0)

    #: Whether or not to expand on a double-click.
    expands_on_dclick = Bool(True)

    #: Whether the labels should be wrapped around, if not an ellipsis is shown
    #: This works only in the qt backend and if there is only one column in tree
    word_wrap = Bool(False)