コード例 #1
0
class SourceAdderNode(ListAdderNode):
    """ Tree node offering the user a panel from which a scene source can
    be added.

    The panel shows a toolbar button that loads a data file (its format is
    detected automatically) above the list of registry sources that do not
    declare any file extension.
    """

    # Toolbar button that loads a data file; the format is checked
    # automatically by the menu helper's open-file action.
    open_file = ToolbarButton('Load data from file',
                              orientation='horizontal',
                              image=ImageResource('file.png'))

    # Registry sources without file extensions (i.e. sources that are not
    # file readers); these populate the list of items shown to the user.
    # NOTE: the loop variable is deliberately named ``source`` (Python 2
    # leaks it into the class namespace).
    items_list_source = [source for source in registry.sources
                         if len(source.extensions) == 0]

    # Text displayed next to the icon in the TreeEditor.
    label = 'Add Data Source'

    # Icon file used for this node in the tree.
    icon_name = Str('source.ico')

    def default_traits_view(self):
        """ Build the view shown in the Mayavi current-object panel. """
        file_button_group = Group(Item('open_file', style='custom'),
                                  show_labels=False,
                                  show_border=False)
        sources_item = Item('items_list',
                            style='readonly',
                            editor=ListEditor(style='custom'))
        return View(Group(file_button_group,
                          sources_item,
                          show_labels=False,
                          label='Add a data source'))

    def _open_file_fired(self):
        """ Forward a click on ``open_file`` to the menu helper. """
        helper = self.object.menu_helper
        helper.open_file_action()

    def _is_action_suitable(self, object, src):
        """ Every action is accepted; this node applies no filtering. """
        return True
コード例 #2
0
ファイル: adder_node.py プロジェクト: zaherabdulazeez/mayavi
class DocumentedItem(HasTraits):
    """ Container to hold a name and a documentation for an action.

    Clicking the ``add`` button looks up the method whose name is stored in
    ``id`` on ``object.menu_helper`` and calls it without arguments.
    """

    # Name of the action
    name = Str

    # Name of the ``menu_helper`` method run when ``add`` is clicked.
    # Declared explicitly: ``_add_fired`` relies on it, and previously it only
    # existed if passed as an undeclared keyword at construction time.
    id = Str

    # Button to trigger the action
    add = ToolbarButton('Add',
                        orientation='horizontal',
                        image=ImageResource('add.ico'))

    # Object the action will apply on
    object = Any

    # Two lines documentation for the action
    documentation = Str

    view = View(
        '_',
        Item('add', style='custom', show_label=False),
        Item('documentation',
             style='readonly',
             editor=TextEditor(multi_line=True),
             resizable=True,
             show_label=False),
    )

    def _add_fired(self):
        """ Trait handler for when the ``add`` button is clicked in
            one of the sub objects in the list.

            Resolves ``self.id`` as an attribute of
            ``self.object.menu_helper`` and calls it; an unknown name raises
            AttributeError from ``getattr``.
        """
        action = getattr(self.object.menu_helper, self.id)
        action()
コード例 #3
0
ファイル: coder.py プロジェクト: imagect/imagect
class CodeEditorDemo(HasTraits):
    """ CodeEditor demo with a toolbar button that runs the edited code.

    Pressing ``run`` sends the current contents of ``code_sample`` to the
    imagect console for execution.
    """

    # Toolbar button that triggers execution of ``code_sample``.
    run = ToolbarButton()

    # Toolbar holding the run button, without a label.
    toolbar = Group(Item("run", show_label=False))

    # Editable source; initialised with a small matplotlib example.
    code_sample = Code("""%matplotlib
import imagect.api.util as iu
from matplotlib import pyplot as plt
import numpy as np
X = np.linspace(-np.pi, np.pi, 256, endpoint=True)
C, S = np.cos(X), np.sin(X)
plt.plot(X, C)
plt.plot(X, S)
plt.show()
""")

    # Group showing the code editor in its custom style.
    code_group = Group(
        Item('code_sample',
             style='custom',
             label='Custom',
             show_label=False),
    )

    # Main window: toolbar stacked above the editor, closed with OK.
    view = View(VGroup(toolbar, code_group),
                title='CodeEditor',
                buttons=['OK'])

    def _run_fired(self):
        """ Execute the editor contents in the imagect console. """
        import imagect.api.util as api_util
        api_util.console.execute(self.code_sample)
コード例 #4
0
class TargetPosition(HasTraits):
    """GUI model for setting the target position of the translator.

    Four toolbar buttons move ``value`` one step down/up (``<``/``>``) or ten
    steps down/up (``<<<``/``>>>``); the step size is the ``step`` trait.
    Each press also fires the ``button_fired`` event.

    Examples
    --------

    >>> p = TargetPosition(value = 0.5, step = 0.2)
    >>> p._less_fired() #called on button '<' pressed: step down by 0.2
    >>> p.value
    0.3
    >>> p._much_more_fired() #called on button '>>>' pressed: step up by 10x0.2
    >>> p.value
    2.3
    """
    #: step size
    step = Float(1.)
    #: position value, bounded by the ``low`` and ``high`` traits
    value = Range(low='low', high='high', value=0)
    low = Float(-1000.)
    high = Float(1000.)
    less = ToolbarButton(' < ', desc='step down')
    more = ToolbarButton(' > ', desc='step up')
    much_less = ToolbarButton('<<<', desc='step down 10x')
    much_more = ToolbarButton('>>>', desc='step up 10x')

    button_fired = Event()

    view = position_view

    def _shift_value(self, delta):
        """Add ``delta`` to ``value`` and then fire ``button_fired``."""
        self.value = self.value + delta
        self.button_fired = True

    def _much_less_fired(self):
        self._shift_value(-(self.step * 10))

    def _much_more_fired(self):
        self._shift_value(self.step * 10)

    def _less_fired(self):
        self._shift_value(-self.step)

    def _more_fired(self):
        self._shift_value(self.step)
コード例 #5
0
ファイル: rtrace.py プロジェクト: simvisage/bmcs
class RTrace(HasStrictTraits):
    name = Str('unnamed')
    record_on = Enum('update', 'iteration')
    clear_on = Enum('never', 'update')
    # single-valued Enum: ``None`` is the only accepted value here
    save_on = Enum(None)

    # Weak references to the owning objects (``rmgr`` is presumably the
    # response manager - it provides ``dir``; meaning of ``sd`` not visible
    # here). Marked ``transient`` so they are excluded from pickling; the
    # metadata key was previously misspelled ``trantient`` and had no effect.
    rmgr = WeakRef(transient=True)
    sd = WeakRef(transient=True)

    # path to directory to store the data (taken from ``rmgr.dir``)
    dir = Property

    def _get_dir(self):
        return self.rmgr.dir

    # path to the file to store the data
    file = File

    def setup(self):
        '''Prepare the tracer for recording.

        No-op in this class; presumably overridden by concrete tracers.
        '''

    def close(self):
        '''Close the tracer - save its values to file.

        No-op in this class; presumably overridden by concrete tracers.
        '''

    # Toolbar button triggering ``refresh``. ``transient`` keeps it out of
    # pickles (the key was previously misspelled ``trantient``).
    refresh_button = ToolbarButton('Refresh', style='toolbar', transient=True)

    @on_trait_change('refresh_button')
    def refresh(self, event=None):
        '''Redraw the trace; also invoked when ``refresh_button`` is clicked.

        NOTE(review): ``redraw`` is not among the members visible here;
        presumably subclasses provide it - confirm.
        '''
        self.redraw()

    def add_current_values(self, sctx, U_k, t, *args, **kw):
        '''Record values for the current state; no-op in this class.'''

    # TODO: exists only to avoid class checking in rmngr - UGLY
    def add_current_displ(self, sctx, t, U_k):
        '''Record the current displacement; no-op in this class.'''

    def register_mv_pipelines(self, e):
        '''
コード例 #6
0
class MFnLineArray(HasTraits):
    '''Function defined by sampled points ``xdata``/``ydata``.

    ``get_values``/``get_diffs`` evaluate a ``scipy.interpolate`` spline
    (vectorized); ``get_value``/``get_diff`` evaluate the piecewise-linear
    interpolant at a scalar point.
    '''

    # Public Traits
    xdata = Array(float, value=[0.0, 1.0])

    def _xdata_default(self):
        '''
        convenience default - when xdata not defined created automatically as
        an array of integers with the same shape as ydata
        '''
        return np.arange(self.ydata.shape[0])

    ydata = Array(float, value=[0.0, 1.0])

    # Behaviour of get_values outside [xdata[0], xdata[-1]]:
    #   'constant'  - clamp to the first/last ydata value
    #   'exception' - raise ValueError
    #   'diff'      - extrapolate with the spline itself
    #   'zero'      - return 0.0
    extrapolate = Enum('constant', 'exception', 'diff', 'zero')

    # alternative vectorized interpolation using scipy.interpolate
    def get_values(self, x, k=1):
        '''
        vectorized interpolation, k is the spline order, default set to 1 (linear)

        x may be a scalar or array-like; the result is always a 1d array.
        Raises ValueError when ``extrapolate == 'exception'`` and any x lies
        outside [xdata[0], xdata[-1]].
        '''
        tck = ip.splrep(self.xdata, self.ydata, s=0, k=k)

        x = np.array([x]).flatten()
        x_lo, x_hi = self.xdata[0], self.xdata[-1]

        if self.extrapolate == 'exception':
            # every requested point must lie inside the sampled range
            if not np.all((x >= x_lo) & (x <= x_hi)):
                raise ValueError('value(s) outside interpolation range')

        values = ip.splev(x, tck, der=0)

        if self.extrapolate == 'constant':
            values[x < x_lo] = self.ydata[0]
            values[x > x_hi] = self.ydata[-1]
        elif self.extrapolate == 'zero':
            values[x < x_lo] = 0.0
            values[x > x_hi] = 0.0
        return values

    def _bracket(self, x):
        '''Return (x1, x2, y1, y2) of the data segment used to evaluate at x.

        Points left of xdata[0] use the first segment and points right of
        xdata[-1] use the last one (linear extrapolation).
        '''
        x2idx = self.xdata.searchsorted(x)
        # clamp so that x1idx = x2idx - 1 never wraps around to index -1
        x2idx = min(max(x2idx, 1), len(self.xdata) - 1)
        x1idx = x2idx - 1
        return (self.xdata[x1idx], self.xdata[x2idx],
                self.ydata[x1idx], self.ydata[x2idx])

    def get_value(self, x):
        '''Piecewise-linear value at the scalar x.'''
        x1, x2, y1, y2 = self._bracket(x)
        dx = x2 - x1
        dy = y2 - y1
        return y1 + dy / dx * (x - x1)

    data_changed = Event

    def get_diffs(self, x, k=1, der=1):
        '''
        vectorized interpolation, der is the nth derivative, default set to 1;
        k is the spline order of the data interpolation, default set to 1 (linear)
        '''
        xdata = np.sort(np.hstack((self.xdata, x)))
        # drop duplicate abscissae, splrep requires strictly increasing x
        idx = np.argwhere(np.diff(xdata) == 0).flatten()
        xdata = np.delete(xdata, idx)
        tck = ip.splrep(xdata, self.get_values(xdata, k=k), s=0, k=k)
        return ip.splev(x, tck, der=der)

    def get_diff(self, x):
        '''Slope of the piecewise-linear interpolant at the scalar x.'''
        x1, x2, y1, y2 = self._bracket(x)
        return (y2 - y1) / (x2 - x1)

    def __call__(self, x):
        return self.get_values(x)

    dump_button = ToolbarButton('Print data',
                                style='toolbar')

    @on_trait_change('dump_button')
    def print_data(self, event=None):
        # single-argument print() behaves the same on Python 2 and 3
        print('x =  %r' % (self.xdata,))
        print('y =  %r' % (self.ydata,))

    # depends on both arrays, otherwise the cached integral goes stale when
    # only xdata changes
    integ_value = Property(Float(), depends_on=['xdata', 'ydata'])

    @cached_property
    def _get_integ_value(self):
        _xdata = self.xdata
        _ydata = self.ydata
        # integral under the stress strain curve
        return np.trapz(_ydata, _xdata)

    def clear(self):
        self.xdata = np.array([])
        self.ydata = np.array([])

    def plot(self, axes, *args, **kw):
        self.mpl_plot(axes, *args, **kw)

    def mpl_plot(self, axes, *args, **kw):
        '''plot within matplotlib window'''
        axes.plot(self.xdata, self.ydata, *args, **kw)
コード例 #7
0
ファイル: rt_dof.py プロジェクト: simvisage/cbfe
class RTraceGraph(RTrace):
    '''
    Collects two response evaluators to make a response graph.

    The supplied strings for var_x and var_y are used to locate the rte in
    the current response manager. The bind method is used to navigate to
    the rte and is stored in here as var_x_eval and var_y_eval as Callable
    object.

    The request for new response evaluation is launched by the time loop
    and directed further by the response manager. This method is used solely
    for collecting the data, not for their visualization in the viewer.

    The timer_tick method is invoked when the visualization of the Graph
    should be synchronized with the actual contents.
    '''

    label = Str('RTraceGraph')
    var_x = Str('')
    # ``transient`` keeps the bound evaluators out of pickles (the key was
    # previously misspelled ``trantient`` and therefore ignored)
    var_x_eval = Callable(transient=True)
    idx_x_arr = Array
    idx_x = Int(-1, enter_set=True, auto_set=False)
    var_y = Str('')
    var_y_eval = Callable(transient=True)
    idx_y_arr = Array
    idx_y = Int(-1, enter_set=True, auto_set=False)
    # Python expressions in terms of ``x`` / ``y`` applied element-wise
    transform_x = Str(enter_set=True, auto_set=False)
    transform_y = Str(enter_set=True, auto_set=False)

    trace = Instance(MFnLineArray)

    def _trace_default(self):
        return MFnLineArray()

    print_button = ToolbarButton('Print Values',
                                 style='toolbar',
                                 transient=True)

    @on_trait_change('print_button')
    def print_values(self, event=None):
        print('x:\t%s\ny:\t%s' % (self.trace.xdata, self.trace.ydata))

    view = View(VSplit(
        VGroup(
            HGroup(
                VGroup(
                    HGroup(Spring(), Item('var_x', style='readonly'),
                           Item('idx_x', show_label=False)),
                    Item('transform_x')),
                VGroup(
                    HGroup(Spring(), Item('var_y', style='readonly'),
                           Item('idx_y', show_label=False)),
                    Item('transform_y')), VGroup('record_on', 'clear_on')),
            HGroup(Item('refresh_button', show_label=False),
                   Item('print_button', show_label=False)),
        ),
        Item('trace@',
             editor=MFnMatplotlibEditor(adapter=MFnPlotAdapter(
                 var_x='var_x', var_y='var_y', min_size=(100, 100))),
             show_label=False,
             resizable=True),
    ),
        buttons=[OKButton, CancelButton],
        resizable=True,
        scrollable=True,
        height=0.5,
        width=0.5)

    # one flattened sample array per recorded step
    _xdata = List(Array(float))
    _ydata = List(Array(float))

    def _rte_for_var(self, var_name):
        '''Return the evaluator registered as var_name in rmgr.rte_dict.

        Raises KeyError (listing the available names) when it is missing
        or registered as None.
        '''
        rte_dict = self.rmgr.rte_dict
        evaluator = rte_dict.get(var_name, None)
        if evaluator is None:
            raise KeyError('Variable %s not present in the dictionary:\n%s' %
                           (var_name, list(rte_dict.keys())))
        return evaluator

    def bind(self):
        '''
        Locate the evaluators
        '''
        self.var_x_eval = self._rte_for_var(self.var_x)
        self.var_y_eval = self._rte_for_var(self.var_y)

    def setup(self):
        self.clear()

    def close(self):
        self.write()

    def write(self):
        '''Redraw the trace and save its (x, y) columns as a text file
        named after label, var_x and var_y inside ``self.dir``.
        '''
        file_base_name = 'rtrace_diagramm_%s (%s,%s).dat' % \
            (self.label, self.var_x, self.var_y)
        # full path to the data file
        file_name = os.path.join(self.dir, file_base_name)
        self.refresh()
        np.savetxt(file_name,
                   np.vstack([self.trace.xdata, self.trace.ydata]).T)

    def add_current_values(self, sctx, U_k, *args, **kw):
        '''
        Invoke the evaluators in the current context for the specified control vector U_k.
        '''
        x = self.var_x_eval(sctx, U_k, *args, **kw)
        y = self.var_y_eval(sctx, U_k, *args, **kw)

        self.add_pair(x.flatten(), y.flatten())

    def add_pair(self, x, y):
        # copies protect the history against later in-place changes
        self._xdata.append(np.copy(x))
        self._ydata.append(np.copy(y))

    @staticmethod
    def _select_column(samples, idx, idx_arr, axis_name):
        '''Return column idx of the stacked samples, or the row-wise sum
        over the columns in idx_arr when idx_arr is non-empty.'''
        data = np.array(samples)
        if len(idx_arr) > 0:
            print('%s: summation for %s' % (axis_name, idx_arr))
            return data[:, idx_arr].sum(1)
        return data[:, idx]

    def _eval_transform(self, expression, var_name, values):
        '''Apply ``expression`` (written in terms of ``var_name``) to every
        element of ``values`` and return the resulting object array.

        SECURITY: the expression is run through eval() with the module
        globals - only use trusted, user-typed transforms.
        '''
        # compile once instead of re-parsing the string for every element
        code = compile(expression, '<transform_%s>' % var_name, 'eval')

        def transform_fn(value):
            return eval(code, globals(), {var_name: value, 'self': self})

        return np.frompyfunc(transform_fn, 1, 1)(values)

    @on_trait_change('idx_x,idx_y')
    def redraw(self, e=None):
        '''Rebuild trace.xdata/ydata from the recorded samples.

        Does nothing until an index (or index array) is chosen for both
        axes and at least one sample has been recorded.
        '''
        if ((self.idx_x < 0 and len(self.idx_x_arr) == 0)
                or (self.idx_y < 0 and len(self.idx_y_arr) == 0)
                or not self._xdata or not self._ydata):
            return

        xarray = self._select_column(self._xdata, self.idx_x,
                                     self.idx_x_arr, 'x')
        yarray = self._select_column(self._ydata, self.idx_y,
                                     self.idx_y_arr, 'y')

        if self.transform_x:
            xarray = self._eval_transform(self.transform_x, 'x', xarray)
        if self.transform_y:
            yarray = self._eval_transform(self.transform_y, 'y', yarray)

        self.trace.xdata = np.array(xarray)
        self.trace.ydata = np.array(yarray)
        self.trace.data_changed = True

    def timer_tick(self, e=None):
        # @todo: unify with redraw
        pass

    def clear(self):
        self._xdata = []
        self._ydata = []
        self.trace.clear()
        self.redraw()
コード例 #8
0
ファイル: adder_node.py プロジェクト: zaherabdulazeez/mayavi
 class MyDocumentedItem(DocumentedItem):
     """DocumentedItem whose ``add`` button is labelled with ``name``."""
     # NOTE(review): ``name`` is not defined in this snippet; presumably it
     # is a variable of an enclosing factory function - confirm.
     add = ToolbarButton('%s' % name,
                         orientation='horizontal',
                         image=ImageResource('add.ico'))
コード例 #9
0
ファイル: rt_dof.py プロジェクト: simvisage/bmcs
class RTDofGraph(RTrace, BMCSLeafNode, Vis2D):
    '''
    Collects two response evaluators to make a response graph.

    The supplied strings for var_x and var_y are used to locate the rte in
    the current response manager. The bind method is used to navigate to
    the rte and is stored in here as var_x_eval and var_y_eval as Callable
    object.

    The request for new response evaluation is launched by the time loop
    and directed further by the response manager. This method is used solely
    for collecting the data, not for their visualization in the viewer.

    The timer_tick method is invoked when the visualization of the Graph
    should be synchronized with the actual contents.
    '''

    label = Str('RTDofGraph')
    var_x = Str('', label='Variable on x', enter_set=True, auto_set=False)
    # when set, each new x sample is added to the previous one
    cum_x = Bool(label='Cumulative x', enter_set=True, auto_set=False)
    # ``transient`` keeps the bound evaluators out of pickles (the key was
    # previously misspelled ``trantient`` and therefore ignored)
    var_x_eval = Callable(transient=True)
    idx_x = Int(-1, enter_set=True, auto_set=False)
    var_y = Str('', label='Variable on y', enter_set=True, auto_set=False)
    cum_y = Bool(label='Cumulative y', enter_set=True, auto_set=False)
    var_y_eval = Callable(transient=True)
    idx_y = Int(-1, enter_set=True, auto_set=False)
    # Python expressions in terms of ``x`` / ``y`` applied element-wise
    transform_x = Str(enter_set=True, auto_set=False)
    transform_y = Str(enter_set=True, auto_set=False)

    trace = Instance(MFnLineArray)
    # time value of every recorded sample; ``np.float`` (an alias of the
    # builtin) was removed in NumPy 1.24 and broke class creation
    _tdata = List(float)

    def _trace_default(self):
        return MFnLineArray()

    print_button = ToolbarButton('Print values',
                                 style='toolbar',
                                 transient=True)

    @on_trait_change('print_button')
    def print_values(self, event=None):
        print('x:\t', self.trace.xdata, '\ny:\t', self.trace.ydata)

    # one flattened sample array per recorded step
    _xdata = List(Array(float))
    _ydata = List(Array(float))

    def _rte_for_var(self, var_name):
        '''Return the evaluator registered as var_name in rmgr.rte_dict.

        Raises KeyError (listing the available names) when it is missing
        or registered as None.
        '''
        rte_dict = self.rmgr.rte_dict
        evaluator = rte_dict.get(var_name, None)
        if evaluator is None:
            raise KeyError('Variable %s not present in the dictionary:\n%s' %
                           (var_name, list(rte_dict.keys())))
        return evaluator

    def bind(self):
        '''
        Locate the evaluators
        '''
        self.var_x_eval = self._rte_for_var(self.var_x)
        self.var_y_eval = self._rte_for_var(self.var_y)

    def setup(self):
        self.clear()

    def close(self):
        self.write()

    def write(self):
        '''Redraw the trace and save its (x, y) columns as a text file
        named after label, var_x and var_y inside ``self.dir``.
        '''
        file_base_name = 'rtrace_diagramm_%s (%s,%s).dat' % \
            (self.label, self.var_x, self.var_y)
        # full path to the data file
        file_name = os.path.join(self.dir, file_base_name)
        self.refresh()
        np.savetxt(file_name,
                   np.vstack([self.trace.xdata, self.trace.ydata]).T)

    def add_current_values(self, sctx, U_k, t, *args, **kw):
        '''
        Invoke the evaluators in the current context for the specified control vector U_k.
        '''
        x = self.var_x_eval(sctx, U_k, *args, **kw)
        y = self.var_y_eval(sctx, U_k, *args, **kw)

        self.add_pair(x.flatten(), y.flatten(), t)

    def add_pair(self, x, y, t):
        '''Append one (x, y, t) sample, accumulating x/y when cum_x/cum_y.'''
        if self.cum_x and len(self._xdata) > 0:
            self._xdata.append(self._xdata[-1] + x)
        else:
            self._xdata.append(np.copy(x))
        if self.cum_y and len(self._ydata) > 0:
            self._ydata.append(self._ydata[-1] + y)
        else:
            self._ydata.append(np.copy(y))
        self._tdata.append(t)

    def _eval_transform(self, expression, var_name, values):
        '''Apply ``expression`` (written in terms of ``var_name``) to every
        element of ``values`` and return the resulting object array.

        SECURITY: the expression is run through eval() with the module
        globals - only use trusted, user-typed transforms.
        '''
        # compile once instead of re-parsing the string for every element
        code = compile(expression, '<transform_%s>' % var_name, 'eval')

        def transform_fn(value):
            return eval(code, globals(), {var_name: value, 'self': self})

        return np.frompyfunc(transform_fn, 1, 1)(values)

    @on_trait_change('idx_x,idx_y')
    def redraw(self, e=None):
        '''Rebuild trace.xdata/ydata from columns idx_x/idx_y of the
        recorded samples and replot; no-op before any sample exists.'''
        if not self._xdata or not self._ydata:
            return

        xarray = np.array(self._xdata)[:, self.idx_x]
        yarray = np.array(self._ydata)[:, self.idx_y]

        if self.transform_x:
            xarray = self._eval_transform(self.transform_x, 'x', xarray)
        if self.transform_y:
            yarray = self._eval_transform(self.transform_y, 'y', yarray)

        self.trace.xdata = np.array(xarray)
        self.trace.ydata = np.array(yarray)
        self.trace.replot()

    def timer_tick(self, e=None):
        # @todo: unify with redraw
        pass

    def clear(self):
        self._xdata = []
        self._ydata = []
        self._tdata = []
        self.trace.clear()
        self.redraw()

    viz2d_classes = {'diagram': RTraceViz2D}

    traits_view = View(VSplit(
        VGroup(
            HGroup(
                VGroup(
                    HGroup(Spring(), Item('var_x', style='readonly'),
                           Item('idx_x', show_label=False)),
                    Item('transform_x')),
                VGroup(
                    HGroup(Spring(), Item('var_y', style='readonly'),
                           Item('idx_y', show_label=False)),
                    Item('transform_y')), VGroup('record_on', 'clear_on')),
            HGroup(Item('refresh_button', show_label=False),
                   Item('print_button', show_label=False)),
        ),
        Item('trace@', show_label=False, resizable=True),
    ),
        buttons=[OKButton, CancelButton],
        resizable=True,
        scrollable=True,
        height=0.5,
        width=0.5)

    tree_view = View(
        Include('actions'),
        Item('var_x', style='readonly'),
        Item('idx_x', show_label=False),
    )
コード例 #10
0
class MFnLineArray(BMCSLeafNode):
    '''Function defined by sampled points ``xdata``/``ydata``.

    Evaluation (``values``/``__call__``) and derivatives (``diff``) use
    ``scipy.interpolate`` splines; the data are also drawn into ``figure``.
    '''

    # Public Traits
    xdata = Array(float, value=[0.0, 1.0])

    def _xdata_default(self):
        '''
        convenience default - when xdata not defined created automatically as
        an array of integers with the same shape as ydata
        '''
        return np.arange(self.ydata.shape[0])

    ydata = Array(float, value=[0.0, 1.0])

    def __init__(self, *args, **kw):
        super(MFnLineArray, self).__init__(*args, **kw)
        self.replot()

    # Behaviour of values() outside [xdata[0], xdata[-1]]:
    #   'constant'  - clamp to the first/last ydata value
    #   'exception' - raise ValueError
    #   'diff'      - extrapolate with the spline itself
    #   'zero'      - return 0.0
    extrapolate = Enum('constant', 'exception', 'diff', 'zero')

    # Vectorized interpolation using scipy.interpolate
    def values(self, x, k=1):
        '''
        vectorized interpolation, k is the spline order, default set to 1 (linear)

        x may be a scalar or array-like; the result is always a 1d array.
        Raises ValueError when ``extrapolate == 'exception'`` and any x lies
        outside [xdata[0], xdata[-1]].
        '''
        tck = ip.splrep(self.xdata, self.ydata, s=0, k=k)

        x = np.array([x]).flatten()
        x_lo, x_hi = self.xdata[0], self.xdata[-1]

        if self.extrapolate == 'exception':
            # every requested point must lie inside the sampled range
            if not np.all((x >= x_lo) & (x <= x_hi)):
                raise ValueError('value(s) outside interpolation range')

        y = ip.splev(x, tck, der=0)

        if self.extrapolate == 'constant':
            y[x < x_lo] = self.ydata[0]
            y[x > x_hi] = self.ydata[-1]
        elif self.extrapolate == 'zero':
            y[x < x_lo] = 0.0
            y[x > x_hi] = 0.0
        return y

    def __call__(self, x):
        return self.values(x)

    # (min, max) of the values on the vertical axis
    yrange = Property

    def _get_yrange(self):
        return np.min(self.ydata), np.max(self.ydata)

    # (min, max) of the values on the horizontal axis
    xrange = Property

    def _get_xrange(self):
        return np.min(self.xdata), np.max(self.xdata)

    data_changed = Event

    figure = Instance(Figure)

    def _figure_default(self):
        figure = Figure(facecolor='white')
        return figure

    def diff(self, x, k=1, der=1):
        '''
        vectorized interpolation, der is the nth derivative, default set to 1;
        k is the spline order of the data interpolation, default set to 1 (linear)
        '''
        xdata = np.sort(np.hstack((self.xdata, x)))
        # drop duplicate abscissae, splrep requires strictly increasing x
        idx = np.argwhere(np.diff(xdata) == 0).flatten()
        xdata = np.delete(xdata, idx)
        tck = ip.splrep(xdata, self.values(xdata, k=k), s=0, k=k)
        return ip.splev(x, tck, der=der)

    dump_button = ToolbarButton('Print data', style='toolbar')

    @on_trait_change('dump_button')
    def print_data(self, event=None):
        print('x = ', repr(self.xdata))
        print('y = ', repr(self.ydata))

    # depends on both arrays, otherwise the cached integral goes stale when
    # only xdata changes
    integ = Property(Float(), depends_on=['xdata', 'ydata'])

    @cached_property
    def _get_integ(self):
        _xdata = self.xdata
        _ydata = self.ydata
        # integral under the stress strain curve
        return np.trapz(_ydata, _xdata)

    def clear(self):
        self.xdata = np.array([])
        self.ydata = np.array([])

    def plot(self, axes, *args, **kw):
        self.mpl_plot(axes, *args, **kw)

    def mpl_plot(self, axes, *args, **kw):
        '''plot within matplotlib window'''
        axes.plot(self.xdata, self.ydata, *args, **kw)

    def mpl_plot_diff(self, axes, *args, **kw):
        '''plot the first derivative on a twin y-axis of ``axes``'''
        ax_dx = axes.twinx()
        # np.linspace requires an integer number of samples
        x = np.linspace(self.xdata[0], self.xdata[-1],
                        np.size(self.xdata) * 20)
        y_dx = self.diff(x, k=1, der=1)
        ax_dx.plot(x, y_dx, *args + ('-', ), **kw)

    # draw the derivative as well in replot()
    plot_diff = Bool(False)

    def replot(self):
        '''Redraw data (and optionally derivative) into ``figure`` and fire
        ``data_changed``.'''
        self.figure.clf()
        ax = self.figure.add_subplot(111)
        self.mpl_plot(ax)
        if self.plot_diff:
            self.mpl_plot_diff(ax, color='orange')
        self.data_changed = True

    def savefig(self, fname):
        '''Save data and derivative plot to ``fname`` via pyplot.'''
        fig = plt.figure()
        try:
            ax = fig.add_subplot(111)
            self.mpl_plot(ax)
            self.mpl_plot_diff(ax, color='orange')
            fig.savefig(fname)
        finally:
            # release the pyplot-managed figure, otherwise figures accumulate
            plt.close(fig)

    tree_view = View(
        VGroup(
            VGroup(
                #                 UItem('figure', editor=MPLFigureEditor(),
                #                       resizable=True,
                #                       springy=True),
                #                 scrollable=True,
            ), ))

    traits_view = tree_view
コード例 #11
0
class DataFrameAnalyzerView(ModelView):
    """ Flexible ModelView class for a DataFrameAnalyzer.

    The view is built using many methods to build each component of the view so
    it can easily be subclassed and customized.

    TODO: add traits events to pass update/refresh notifications to the
     DFEditors once we have updated TraitsUI.

    TODO: Add traits events to receive notifications that a column/row was
     clicked/double-clicked.
    """
    #: Model being viewed
    model = Instance(DataFrameAnalyzer)

    #: Selected list of data columns to display and analyze
    visible_columns = List(Str)

    #: Check box to hide/show what stats are included in the summary DF
    show_summary_controls = Bool

    #: Show the summary categorical df
    show_categorical_summary = Bool(True)

    #: Check box to hide/show what columns to analyze (panel when few columns)
    show_column_controls = Bool

    #: Open control for what columns to analyze (popup when many columns)
    open_column_controls = Button("Show column control")

    #: Button to launch the plotter tool when plotter_layout=popup
    plotter_launcher = Button("Launch Plot Tool")

    # Plotting tool attributes ------------------------------------------------

    #: Does the UI expose a DF plotter?
    include_plotter = Bool

    #: Plot manager view to display. Ignored if include_plotter is False.
    plotter = Instance(DataFramePlotManagerView)

    # Styling and branding attributes -----------------------------------------

    #: String describing the font to use, or dict mapping column names to font
    fonts = Either(Str, Dict)

    #: Name of the font to use if same across all columns
    font_name = Str(DEFAULT_FONT)

    #: Size of the font to use if same across all columns
    font_size = Int(14)

    #: Number of significant digits ('%.Ng') used for numeric columns in the
    #: tables. A negative value formats every cell with '%s' instead.
    display_precision = Int(-1)

    #: Cell format string, or dict mapping column names to format strings
    formats = Either(Str, Dict)

    #: UI title for the Data section
    data_section_title = Str("Data")

    #: Exploration group label: visible only when plotter_layout="Tabbed"
    exploration_group_label = Str("Exploration Tools")

    #: Plotting group label: visible only when plotter_layout="Tabbed"
    plotting_group_label = Str("Plotting Tools")

    #: UI title for the data summary section
    summary_section_title = Str

    #: UI title for the categorical data summary section
    cat_summary_section_title = Str("Categorical data summary")

    #: UI title for the column list section
    column_list_section_title = Str("Column content")

    #: UI title for the summary content section
    summary_content_section_title = Str("Summary content")

    #: UI summary group (tab) name for numerical columns
    num_summary_group_name = Str("Numerical data")

    #: UI summary group (tab) name for categorical columns
    cat_summary_group_name = Str("Categorical data")

    #: Text to display in title bar of the containing window (if applicable)
    app_title = Str("Tabular Data Analyzer")

    #: How to place the plotter tool with respect to the exploration tool?
    plotter_layout = Enum("Tabbed", "HSplit", "VSplit", "popup")

    #: DFPlotManager traits to customize it
    plotter_kw = Dict

    #: Message displayed below the table if truncated. It is rebuilt when
    #: either the row count or the message template changes.
    truncation_msg = Property(
        Str,
        depends_on=["model.num_displayed_rows", "truncation_msg_template"]
    )

    # Functionality controls --------------------------------------------------

    #: Button to shuffle the order of the filtered data
    shuffle_button = Button("Shuffle")

    #: Whether the shuffle button is visible in the data section
    show_shuffle_button = Bool(True)

    #: Button to display more rows in the data table
    show_more_button = Button

    #: Button to display all rows in the data table
    show_all_button = Button("Show All")

    #: Apply button for the filter if model not in auto-apply mode
    apply_filter_button = ToolbarButton(image=apply_img)

    #: Edit the filter in a pop-out dialog
    pop_out_filter_button = ToolbarButton(image=pop_out_img)

    #: Whether to support saving, and loading filters
    filter_manager = Bool

    #: Button to launch filter expression manager to load an existing filter
    load_filter_button = ToolbarButton(image=load_img)

    #: Button to save current filter expression
    save_filter_button = ToolbarButton(image=save_img)

    #: Button to launch filter expression manager to modify saved filters
    manage_filter_button = ToolbarButton(image=manage_img)

    #: List of saved filtered expressions
    _known_expr = Property(Set, depends_on="model.known_filter_exps")

    #: Show the bottom panel with the summary of the data:
    _show_summary = Bool(True)

    #: Whether the "Show summary" check box is offered to the user
    allow_show_summary = Bool(True)

    #: Button to export the analyzed data to a CSV file
    data_exporter = Button("Export Data to CSV")

    #: Button to export the summary data to a CSV file
    summary_exporter = Button("Export Summary to CSV")

    # Detailed configuration traits -------------------------------------------

    #: View class to use. Modify to customize.
    view_klass = Any(View)

    #: Width of the view
    view_width = Int(1100)

    #: Height of the view
    view_height = Int(700)

    #: Width of the filter box
    filter_item_width = Int(400)

    #: Maximum number of column names per column of the column check-list;
    #: also sets the threshold for switching to the popup column control
    max_names_per_column = Int(12)

    #: Template for truncation_msg; formatted with the number of displayed rows
    truncation_msg_template = Str("Table truncated at {} rows")

    #: Whether to pop up a warning when selected rows are not displayed
    warn_if_sel_hidden = Bool(True)

    #: Text of the warning shown when selected rows are not displayed
    hidden_selection_msg = Str

    #: Column names (as a list) to include in filter editor assistant
    filter_editor_cols = List

    # Implementation details --------------------------------------------------

    #: Evaluate number of columns to select panel or popup column control
    _many_columns = Property(
        Bool, depends_on=["model.column_list", "max_names_per_column"]
    )

    #: Popped-up UI to control the visible columns
    _control_popup = Any

    #: Collected traitsUI editors for both the data DF and the summary DF
    _df_editors = Dict

    # HasTraits interface -----------------------------------------------------

    def __init__(self, **traits):
        """ Build the view, accepting a raw pandas DataFrame as ``model``.

        A DataFrame passed as ``model`` is wrapped into a DataFrameAnalyzer
        (as its ``source_df``). When ``include_plotter`` is set, the
        plotter's model is registered in the model's ``plot_manager_list``.
        """
        if "model" in traits and isinstance(traits["model"], pd.DataFrame):
            traits["model"] = DataFrameAnalyzer(source_df=traits["model"])

        super().__init__(**traits)

        if self.include_plotter:
            # If a plotter view was specified, its model should be in the
            # model's list of plot managers:
            if self.plotter.model not in self.model.plot_manager_list:
                self.model.plot_manager_list.append(self.plotter.model)

    def traits_view(self):
        """ Putting the view components together.

        Each component of the view is built in a separate method so it can
        easily be subclassed and customized.
        """
        # Construction of view groups -----------------------------------------

        data_group = self.view_data_group_builder()
        column_controls_group = self.view_data_control_group_builder()
        summary_group = self.view_summary_group_builder()
        summary_controls_group = self.view_summary_control_group_builder()
        if self.show_categorical_summary:
            cat_summary_group = self.view_cat_summary_group_builder()
        else:
            cat_summary_group = None
        plotter_group = self.view_plotter_group_builder()

        button_content = [
            Item("data_exporter", show_label=False),
            Spring(),
            Item("summary_exporter", show_label=False)
        ]

        if self.plotter_layout == "popup":
            button_content += [
                Spring(),
                Item("plotter_launcher", show_label=False)
            ]

        button_group = HGroup(*button_content)

        # Organization of item groups -----------------------------------------

        # If both types of summary are available, display as Tabbed view:
        if summary_group is not None and cat_summary_group is not None:
            summary_container = Tabbed(
                HSplit(
                    summary_controls_group,
                    summary_group,
                    label=self.num_summary_group_name
                ),
                cat_summary_group,
            )
        elif cat_summary_group is not None:
            summary_container = cat_summary_group
        else:
            summary_container = HSplit(
                summary_controls_group,
                summary_group
            )

        # Allow to hide all summary information:
        summary_container.visible_when = "_show_summary"

        exploration_groups = VGroup(
            VSplit(
                HSplit(
                    column_controls_group,
                    data_group,
                ),
                summary_container
            ),
            button_group,
            label=self.exploration_group_label
        )

        if self.include_plotter and self.plotter_layout != "popup":
            # plotter_layout names a traitsui layout class (Tabbed, HSplit or
            # VSplit) combining the exploration and plotting groups:
            layout = getattr(traitsui.api, self.plotter_layout)
            groups = layout(
                exploration_groups,
                plotter_group
            )
        else:
            groups = exploration_groups

        view = self.view_klass(
            groups,
            resizable=True,
            title=self.app_title,
            width=self.view_width, height=self.view_height
        )
        return view

    # Traits view building methods --------------------------------------------

    def view_data_group_builder(self):
        """ Build view element for the Data display.

        Contains the sort/shuffle/filter toolbar, the (multi-select) table of
        ``model.displayed_df`` and the row display controls shown when the
        table is truncated.
        """
        editor_kw = dict(show_index=True, columns=self.visible_columns,
                         fonts=self.fonts, formats=self.formats)
        data_editor = DataFrameEditor(selected_row="selected_idx",
                                      multi_select=True, **editor_kw)

        filter_group = HGroup(
            Item("model.filter_exp", label="Filter",
                 width=self.filter_item_width),
            Item("pop_out_filter_button", show_label=False, style="custom",
                 tooltip="Open filter editor..."),
            Item("apply_filter_button", show_label=False,
                 visible_when="not model.filter_auto_apply", style="custom",
                 tooltip="Apply current filter"),
            Item("save_filter_button", show_label=False,
                 enabled_when="model.filter_exp not in _known_expr",
                 visible_when="filter_manager", style="custom",
                 tooltip="Save current filter"),
            Item("load_filter_button", show_label=False,
                 visible_when="filter_manager", style="custom",
                 tooltip="Load a filter..."),
            Item("manage_filter_button", show_label=False,
                 visible_when="filter_manager", style="custom",
                 tooltip="Manage filters..."),
        )

        # Condition under which the table shows fewer rows than the filter
        # keeps (only meaningful when not restricted to the selection):
        truncated = ("len(model.displayed_df) < len(model.filtered_df) and "
                     "not model.show_selected_only")
        more_label = "Show {} More".format(self.model.num_display_increment)
        display_control_group = HGroup(
            Item("model.show_selected_only", label="Selected rows only"),
            Item("truncation_msg", style="readonly", show_label=False,
                 visible_when=truncated),
            Item("show_more_button", editor=ButtonEditor(label=more_label),
                 show_label=False, visible_when=truncated),
            Item("show_all_button", show_label=False,
                 visible_when=truncated),
        )

        data_group = VGroup(
            make_window_title_group(self.data_section_title, title_size=3,
                                    include_blank_spaces=False),
            HGroup(
                Item("model.sort_by_col", label="Sort by"),
                Item("shuffle_button", show_label=False,
                     visible_when="show_shuffle_button"),
                Spring(),
                filter_group
            ),
            HGroup(
                Item("model.displayed_df", editor=data_editor,
                     show_label=False),
            ),
            HGroup(
                Item("show_column_controls",
                     label="\u2190 Show column control",
                     visible_when="not _many_columns"),
                Item("open_column_controls", show_label=False,
                     visible_when="_many_columns"),
                Spring(),
                Item("_show_summary", label=u'\u2193 Show summary',
                     visible_when="allow_show_summary"),
                Spring(),
                display_control_group
            ),
            show_border=True
        )
        return data_group

    def view_data_control_group_builder(self, force_visible=False):
        """ Build view element for the Data column control.

        Parameters
        ----------
        force_visible : bool
            Controls visibility of the created group. Don't force for the group
            embedded in the global view, but force it when opened as a popup.
        """
        # Spread the column names over enough check-list columns so that each
        # holds at most about max_names_per_column names:
        num_cols = 1 + len(self.model.column_list) // self.max_names_per_column

        column_controls_group = VGroup(
            make_window_title_group(self.column_list_section_title,
                                    title_size=3, include_blank_spaces=False),
            Item("visible_columns", show_label=False,
                 editor=CheckListEditor(values=self.model.column_list,
                                        cols=num_cols),
                 # The custom style allows to control a list of options rather
                 # than having a checklist editor for a single value:
                 style='custom'),
            show_border=True
        )
        if force_visible:
            column_controls_group.visible_when = ""
        else:
            column_controls_group.visible_when = "show_column_controls"

        return column_controls_group

    def view_summary_group_builder(self):
        """ Build view element for the numerical data summary display
        """
        editor_kw = dict(show_index=True, columns=self.visible_columns,
                         fonts=self.fonts, formats=self.formats)
        summary_editor = DataFrameEditor(**editor_kw)

        summary_group = VGroup(
            make_window_title_group(self.summary_section_title, title_size=3,
                                    include_blank_spaces=False),
            Item("model.summary_df", editor=summary_editor, show_label=False,
                 visible_when="len(model.summary_df) != 0"),
            # Workaround the fact that the Label's visible_when is buggy:
            # encapsulate it into a group and add the visible_when to the group
            HGroup(
                Label("No data columns with numbers were found."),
                visible_when="len(model.summary_df) == 0"
            ),
            HGroup(
                Item("show_summary_controls"),
                Spring(),
                visible_when="len(model.summary_df) != 0"
            ),
            show_border=True,
        )
        return summary_group

    def view_summary_control_group_builder(self):
        """ Build view element for the column controls for data summary.
        """
        summary_controls_group = VGroup(
            make_window_title_group(self.summary_content_section_title,
                                    title_size=3, include_blank_spaces=False),
            Item("model.summary_index", show_label=False),
            visible_when="show_summary_controls",
            show_border=True
        )

        return summary_controls_group

    def view_cat_summary_group_builder(self):
        """ Build view element for the categorical data summary display.
        """
        editor_kw = dict(show_index=True, fonts=self.fonts,
                         formats=self.formats)
        summary_editor = DataFrameEditor(**editor_kw)

        cat_summary_group = VGroup(
            make_window_title_group(self.cat_summary_section_title,
                                    title_size=3, include_blank_spaces=False),
            Item("model.summary_categorical_df", editor=summary_editor,
                 show_label=False,
                 visible_when="len(model.summary_categorical_df)!=0"),
            # Workaround the fact that the Label's visible_when is buggy:
            # encapsulate it into a group and add the visible_when to the group
            HGroup(
                Label("No categorical data columns were found."),
                visible_when="len(model.summary_categorical_df)==0"
            ),
            show_border=True, label=self.cat_summary_group_name
        )
        return cat_summary_group

    def view_plotter_group_builder(self):
        """ Build view element for the plotter tool.
        """
        plotter_group = VGroup(
            Item("plotter", editor=InstanceEditor(), show_label=False,
                 style="custom"),
            label=self.plotting_group_label
        )
        return plotter_group

    # Public interface --------------------------------------------------------

    def destroy(self):
        """ Clean up resources.

        Disposes of the column-control popup, if any, and drops the reference
        so a later request opens a fresh popup instead of reusing a disposed
        one.
        """
        if self._control_popup:
            self._control_popup.dispose()
            self._control_popup = None

    # Traits listeners --------------------------------------------------------

    def _open_column_controls_fired(self):
        """ Pop-up a new view on the column list control.
        """
        if self._control_popup and self._control_popup.control:
            # If there is an existing window, bring it in focus:
            # Discussion: https://stackoverflow.com/questions/2240717/in-qt-how-do-i-make-a-window-be-the-current-window  # noqa
            self._control_popup.control._mw.activateWindow()
            return

        # Before viewing self with a simplified view, make sure the original
        # view editors are collected so they can be modified when the controls
        # are used:
        if not self._df_editors:
            self._collect_df_editors()

        view = self.view_klass(
            self.view_data_control_group_builder(force_visible=True),
            buttons=[OKButton],
            width=600, resizable=True,
            title="Control visible columns"
        )
        # WARNING: this will modify the info object the view points to!
        self._control_popup = self.edit_traits(view=view, kind="live")

    def _shuffle_button_fired(self):
        """ Shuffle the order of the filtered data in the model. """
        self.model.shuffle_filtered_df()

    def _apply_filter_button_fired(self):
        """ Log and apply the current filter expression. """
        flt = self.model.filter_exp
        msg = f"Applying filter {flt}."
        logger.log(ACTION_LEVEL, msg)

        self.model.recompute_filtered_df()

    def _pop_out_filter_button_fired(self):
        """ Edit the filter in a modal dialog and apply it if accepted. """
        if not self.filter_editor_cols:
            # if there are no included columns, then use all categorical cols
            df = self.model.source_df
            cat_df = df.select_dtypes(include=CATEGORICAL_COL_TYPES)
            self.filter_editor_cols = list(cat_df.columns)
        filter_editor = FilterExpressionEditorView(
            expr=self.model.filter_exp, view_klass=self.view_klass,
            source_df=self.model.source_df,
            included_cols=self.filter_editor_cols)
        ui = filter_editor.edit_traits(kind="livemodal")
        if ui.result:
            self.model.filter_exp = filter_editor.expr
            # Firing the button event runs _apply_filter_button_fired:
            self.apply_filter_button = True

    def _manage_filter_button_fired(self):
        """ Open the filter manager and store the edited filter list.

        TODO: review if replacing the copy by a deepcopy or removing the
             copy altogether would help traits trigger listeners correctly
        """
        msg = "Opening filter manager."
        logger.log(ACTION_LEVEL, msg)

        # Make a copy of the list of filters so the model can listen to changes
        # even if only a field of an existing filter is modified:
        filter_manager = FilterExpressionManager(
            known_filter_exps=copy(self.model.known_filter_exps),
            mode="manage", view_klass=self.view_klass
        )
        ui = filter_manager.edit_traits(kind="livemodal")
        if ui.result:
            # FIXME: figure out why this simpler assignment doesn't trigger the
            #  traits listener on the model when changing a FilterExpression
            #  attribute:
            # self.model.known_filter_exps = filter_manager.known_filter_exps

            self.model.known_filter_exps = [
                FilterExpression(name=e.name, expression=e.expression) for e in
                filter_manager.known_filter_exps
            ]

    def _load_filter_button_fired(self):
        """ Let the user pick a saved filter and make it the current one. """
        filter_manager = FilterExpressionManager(
            known_filter_exps=self.model.known_filter_exps,
            mode="load", view_klass=self.view_klass
        )
        ui = filter_manager.edit_traits(kind="livemodal")
        if ui.result:
            selection = filter_manager.selected_expression
            self.model.filter_exp = selection.expression

    def _save_filter_button_fired(self):
        """ Save the current filter expression unless already saved. """
        exp = self.model.filter_exp
        if exp in [e.expression for e in self.model.known_filter_exps]:
            return

        expr = FilterExpression(name=exp, expression=exp)
        self.model.known_filter_exps.append(expr)

    def _show_more_button_fired(self):
        """ Display num_display_increment more rows in the data table. """
        self.model.num_displayed_rows += self.model.num_display_increment

    def _show_all_button_fired(self):
        """ Display all rows (-1 is passed to the model for "all"). """
        self.model.num_displayed_rows = -1

    @on_trait_change("model:selected_data_in_plotter_updated", post_init=True)
    def warn_if_selection_hidden(self):
        """ Pop up warning msg if some of the selected rows aren't displayed.
        """
        if not self.warn_if_sel_hidden:
            return

        if not self.model.selected_idx:
            return

        truncated = len(self.model.displayed_df) < len(self.model.filtered_df)
        # NOTE(review): comparing against the largest displayed index label
        # assumes the displayed rows are those with the smallest labels;
        # confirm this holds after sorting/shuffling.
        max_displayed = self.model.displayed_df.index.max()
        some_selection_hidden = max(self.model.selected_idx) > max_displayed
        if truncated and some_selection_hidden:
            warning(None, self.hidden_selection_msg, "Hidden selection")

    @on_trait_change("visible_columns[]", post_init=True)
    def update_filtered_df_on_columns(self):
        """ Just show the columns that are set to visible.

        Notes
        -----
        We are not modifying the filtered data because if we remove a column
        and then bring it back, the adapter breaks because it is missing data.
        Breakage happen when removing a column if the model is changed first,
        or when bring a column back if the adapter column list is changed
        first.
        """
        # No UI built (info is None) or not finished building: there are no
        # editors whose adapters need updating yet.
        if self.info is None or not self.info.initialized:
            return

        if not self._df_editors:
            self._collect_df_editors()

        # Rebuild the column list (col name, column id) for the tabular
        # adapter:
        all_visible_cols = [(col, col) for col in self.visible_columns]

        df = self.model.source_df
        cat_dtypes = self.model.categorical_dtypes
        summarizable_df = df.select_dtypes(exclude=cat_dtypes)
        summary_visible_cols = [(col, col) for col in self.visible_columns
                                if col in summarizable_df.columns]

        for df_name, cols in zip(["displayed_df", "summary_df"],
                                 [all_visible_cols, summary_visible_cols]):
            # This grabs the corresponding _DataFrameEditor (not the editor
            # factory) which has access to the adapter object.
            # _collect_df_editors logs rather than raises when an editor
            # can't be found, so skip any editor that wasn't collected:
            editor = self._df_editors.get(df_name)
            if editor is None:
                continue

            df = getattr(self.model, df_name)
            index_name = df.index.name
            if index_name is None:
                index_name = ''

            editor.adapter.columns = [(index_name, 'index')] + cols

    def _collect_df_editors(self):
        """ Store the live editors for displayed_df and summary_df.

        Each editor is looked up by name on ``self.info``; lookup failures are
        logged and the corresponding entry is left out of ``_df_editors``.
        """
        for df_name in ["displayed_df", "summary_df"]:
            try:
                # This grabs the corresponding _DataFrameEditor (not the editor
                # factory) which has access to the adapter object:
                self._df_editors[df_name] = getattr(self.info, df_name)
            except Exception as e:
                msg = "Error trying to collect the tabular adapter: {}"
                logger.error(msg.format(e))

    def _plotter_launcher_fired(self):
        """ Pop up plot manager view. Only when self.plotter_layout="popup".
        """
        self.plotter.edit_traits(kind="livemodal")

    def _data_exporter_fired(self):
        """ Save the filtered data to a user-chosen CSV file and open it. """
        filepath = request_csv_file(action="save as")
        if filepath:
            self.model.filtered_df.to_csv(filepath)
            open_file(filepath)

    def _summary_exporter_fired(self):
        """ Save the numerical summary to a user-chosen CSV file and open it.
        """
        filepath = request_csv_file(action="save as")
        if filepath:
            self.model.summary_df.to_csv(filepath)
            open_file(filepath)

    # Traits property getters/setters -----------------------------------------

    def _get__known_expr(self):
        """ Set of the expressions of all saved filters. """
        return {e.expression for e in self.model.known_filter_exps}

    @cached_property
    def _get_truncation_msg(self):
        """ truncation_msg_template formatted with the displayed row count. """
        num_displayed_rows = self.model.num_displayed_rows
        return self.truncation_msg_template.format(num_displayed_rows)

    @cached_property
    def _get__many_columns(self):
        """ Whether the column names would need more than two check-list
        columns, i.e. more than 2 * max_names_per_column names.
        """
        return len(self.model.column_list) > 2 * self.max_names_per_column

    # Traits initialization methods -------------------------------------------

    def _plotter_default(self):
        """ Build the plotter view when include_plotter is set (else None).

        Reuses the model's first plot manager if any, otherwise creates a new
        DataFramePlotManager on the filtered data using ``plotter_kw``.
        """
        if self.include_plotter:
            if self.model.plot_manager_list:
                if len(self.model.plot_manager_list) > 1:
                    num_plotters = len(self.model.plot_manager_list)
                    msg = "Model contains {} plot manager, but only " \
                          "initializing the Analyzer view with the first " \
                          "plot manager available.".format(num_plotters)
                    logger.warning(msg)

                plot_manager = self.model.plot_manager_list[0]
            else:
                plot_manager = DataFramePlotManager(
                    data_source=self.model.filtered_df,
                    source_analyzer=self.model,
                    **self.plotter_kw
                )

            view = DataFramePlotManagerView(model=plot_manager,
                                            view_klass=self.view_klass)
            return view

    def _formats_default(self):
        """ '%s' for all cells, or per-column formats using
        display_precision significant digits for numeric columns.
        """
        if self.display_precision < 0:
            return '%s'
        else:
            formats = {}
            float_format = '%.{}g'.format(self.display_precision)
            for col in self.model.source_df.columns:
                col_dtype = self.model.source_df.dtypes[col]
                if np.issubdtype(col_dtype, np.number):
                    formats[col] = float_format
                else:
                    formats[col] = '%s'

            return formats

    def _visible_columns_default(self):
        """ All of the model's columns are visible initially. """
        return self.model.column_list

    def _hidden_selection_msg_default(self):
        msg = "The displayed data is truncated and some of the selected " \
              "rows isn't displayed in the data table."
        return msg

    def _summary_section_title_default(self):
        """ Title depends on whether a categorical summary exists. """
        if len(self.model.summary_categorical_df) == 0:
            return "Data summary"
        else:
            return "Numerical data summary"

    def _fonts_default(self):
        """ Single font description: '<font_name> <font_size>'. """
        return "{} {}".format(self.font_name, self.font_size)
class ShorthairCat(HasTraits):
    """Traits UI front end for running and viewing a SIMP topology optimization.

    The window pairs a Mayavi engine tree (with an editor for the current
    selection) with a control panel (density filter slider plus
    initialize / calculate / animate buttons). It also holds a tabbed set of
    Mayavi scenes: mesh, displacement, stress, strain, density and animation.

    Every class that declares trait attributes must derive from HasTraits.
    """
    # Value in [0, 1], default 1.0. Its static change handler
    # (_density_filter_changed) passes it to generate_unstrgrid_mesh.
    density_filter = Range(0.0,1.0,1.0)

    # Starts the solver run; handled by _calculate_button_fired.
    calculate_button = ToolbarButton('Calculate')

    # Handled by _initial_button_fired and _animate_button_fired.
    initial_button = Button('initialize')
    animate_button = Button('animate')

    # The scene models; Instance(..., ()) creates a default instance of each.
    scene  = Instance(MlabSceneModel,())  # mesh / geometry scene
    scene0 = Instance(MlabSceneModel,())  # displacement scene
    scene1 = Instance(MlabSceneModel,())  # stress scene
    scene2 = Instance(MlabSceneModel,())  # strain scene
    scene3 = Instance(MlabSceneModel,())  # density scene
    scene4 = Instance(MlabSceneModel,())  # animation scene

    # Pipeline instance used to generate the animation.
    plot = Instance(PipelineBase)

    # The mayavi engine view.
    engine_view = Instance(EngineView)

    # The current selection in the engine tree view.
    current_selection = Property

    # Main window layout: a 'Settings' group beside a vertical group that
    # holds the controls above a tabbed set of scenes.
    main_view = View(
                Group(
                              # Left: 'Settings' group with the engine tree and
                              # an editor for the currently selected object.
                              Group(HSplit(HSplit(VSplit(
                              Item(name='engine_view',
                                   style='custom',
                                   resizable=True,
                                   height =500,
                                   width = 200,
                                   show_label=False

                                   ),
                                     ),
                              Item(name='current_selection',
                                   editor=InstanceEditor(),
                                   enabled_when='current_selection is not None',
                                   style='custom',
                                   resizable = True,
                                   height = 500,
                                   width = 200,
                                   springy=True,
                                   show_label=False),
                                   )),label = 'Settings',show_border = False),
                              # Right: controls stacked above the result scenes.
                              Group(
                                  # Density filter slider, a separator ('_'),
                                  # then the three buttons side by side.
                                  Group(

                                      Item(name = 'density_filter',editor = RangeEditor()),
                                      '_',
                                      HSplit(
                                      Item('initial_button', show_label=False),

                                      Item('calculate_button', show_label=False),
                                      Item('animate_button', show_label=False))

                                  ),
                                  # One tab per scene model (layout='tabbed').
                                  Group(
                                       Item(name='scene',
                                            editor=SceneEditor(),
                                            show_label=False,
                                            resizable=True,
                                            springy = True,
                                            height=600,
                                            width=600,
                                            label = 'mesh'
                                            ),
                                      Item(name='scene0',
                                           editor=SceneEditor(),
                                           show_label=False,
                                           resizable=True,
                                           springy=True,
                                           height=600,
                                           width=600,
                                           label='displacement'
                                           ),
                                       Item(name='scene1',
                                            editor=SceneEditor(),
                                            show_label=False,
                                            resizable=True,
                                            springy=True,
                                            height=600,
                                            width=600,
                                            label = 'stress'
                                            ),
                                       Item(name='scene2',
                                            editor=SceneEditor(),
                                            show_label=False,
                                            resizable=True,
                                            springy=True,
                                            height=600,
                                            width=600,
                                            label = 'strain'
                                            ),
                                      Item(name='scene3',
                                           editor=SceneEditor(),
                                           show_label=False,
                                           resizable=True,
                                           springy=True,
                                           height=600,
                                           width=600,
                                           label='density'
                                           ),
                                      Item(name='scene4',
                                           editor=SceneEditor(),
                                           show_label=False,
                                           resizable=True,
                                           springy=True,
                                           height=600,
                                           width=600,
                                           label='animating'
                                           ),
                                      layout = 'tabbed'),

                                  orientation = 'vertical'),
                    orientation = 'horizontal'
                ),
                height = 600,
                width = 760,
                resizable=True,
                # scrollable=True,
                title = 'ShorthairCat',
                )

    # **traits collects any extra keyword arguments; they are forwarded to
    # HasTraits.__init__.
    def __init__(self, type, r, penal, move, e, nu, volfac, **traits):
        """Name the scenes, hook up the engine view and store the parameters.

        Parameters
        ----------
        type
            Problem type, later passed to
            ``global_variable.initialize_global_variable`` by ``_initial``.
        r, penal, move, e, nu, volfac
            SIMP hyper-parameters, later passed to
            ``global_variable.hyperparameter`` by ``_initial``.
        **traits
            Forwarded to ``HasTraits.__init__``.
        """
        HasTraits.__init__(self, **traits)

        # Give each Mayavi scene a name.
        self.scene.mayavi_scene.name = 'Geometry'
        self.scene.foreground = (1, 170 / 255, 0)
        self.scene0.mayavi_scene.name = 'Displacement'
        self.scene1.mayavi_scene.name = 'Stress'
        self.scene2.mayavi_scene.name = 'Strain'
        self.scene3.mayavi_scene.name = 'Density'
        self.scene4.mayavi_scene.name = 'Animate'

        # Engine tree shown in the 'Settings' group.
        self.engine_view = EngineView(engine=self.scene.engine)

        # Dynamic listener: whenever the engine's current_selection changes,
        # _selection_change is called to re-emit it on this object.
        self.scene.engine.on_trait_change(self._selection_change,
                                          name='current_selection')

        # The solver is created later, by _initial.
        self.simp_solver = None
        self.type = type
        self.r = r
        self.penal = penal
        self.move = move
        self.e = e
        self.nu = nu
        self.volfac = volfac
        # Project root directory. Raw string: the previous plain literal only
        # worked because '\G' and '\T' are invalid escapes (warned about, a
        # SyntaxWarning since Python 3.12). The value is unchanged.
        self.address = r'H:\GitHub\Topology-optimization-of-structure-via-simp-method'
        # Step counter used by the filter schedule in _update_vtkdatasource.
        self.i = 1
    def _initial_button_fired(self):
        """Run ``_initial`` on a daemon thread named 'Thread-1'.

        The thread object is kept on ``self.initial_thread``.
        """
        worker = threading.Thread(target=self._initial, args=(),
                                  name='Thread-1')
        worker.daemon = True
        self.initial_thread = worker
        worker.start()

    def _initial(self):
        """Set up the SIMP model and populate the result scenes.

        Steps:

        1. Push the stored hyper-parameters into ``global_variable``.
        2. Initialize ``global_variable`` for ``self.type``.
        3. Create the ``Simp`` solver and fill the Mayavi scenes.
        4. Subscribe to the solver's ``loop`` trait. ``_update_vtkdatasource``
           refreshes the data sources, and ``_save_fig`` (dispatched on the
           UI thread) saves a density snapshot whenever ``loop`` changes.
        """
        global_variable.hyperparameter(r=self.r, move=self.move, e=self.e,
                                       penal=self.penal, nu=self.nu,
                                       volfac=self.volfac)
        global_variable.initialize_global_variable(type=self.type)

        solver = Simp()
        self.simp_solver = solver
        # NOTE(review): _initial_button_fired runs this on a background
        # thread, so _mayavi touches Mayavi/VTK objects off the UI thread.
        # Confirm that is safe with the GUI toolkit in use.
        self._mayavi()

        solver.on_trait_change(self._update_vtkdatasource, name='loop')
        solver.on_trait_change(self._save_fig, name='loop', dispatch='ui')
    def _save_fig(self):
        """Save a snapshot of the density scene for the current iteration.

        The image is written into the ``fig`` folder under ``self.address``
        as ``density<loop>.png``.
        """
        # Reuse self.address instead of repeating the hard-coded project path
        # (whose '\G' and '\T' were invalid string escapes). The resulting
        # file name is identical.
        fname = (self.address + '\\fig\\' + 'density'
                 + str(self.simp_solver.loop) + '.png')
        self.scene3.mayavi_scene.scene.save(fname)





    def _calculate_button_fired(self):
        """Start the SIMP solver and the convergence plot on daemon threads.

        'Thread-2' runs ``self.simp_solver.simp``. 'Thread-3' runs
        ``_plot_convergence_curve``, which polls the solver to draw the
        convergence curves.

        If the model has not been initialized yet (``simp_solver`` is None),
        a hint is printed and nothing is started.
        """
        if self.simp_solver is None:
            # Previously this raised AttributeError on ``None.simp``.
            print('Please initialize the model before calculating.')
            return

        # The finite-element computation runs in the background.
        self.computation_thread = threading.Thread(
            target=self.simp_solver.simp, args=(), name='Thread-2')
        self.computation_thread.daemon = True
        self.computation_thread.start()

        # The curve is polled from its own thread rather than from a 'loop'
        # listener. The original author noted that dispatch='ui' often froze
        # the application.
        self.plot_thread = threading.Thread(
            target=self._plot_convergence_curve, args=(), name='Thread-3')
        self.plot_thread.daemon = True
        self.plot_thread.start()




    def _animate_button_fired(self):
        """Run ``_animate`` on a daemon background thread."""
        # Pass the bound method itself. ``target=self._animate()`` ran the
        # animation synchronously on the UI thread and handed the thread a
        # None target, so the "background" thread did nothing.
        # NOTE(review): like _initial, _animate then manipulates Mayavi
        # objects off the UI thread; confirm this is safe for the toolkit.
        animate_thread = threading.Thread(target=self._animate, args=())
        animate_thread.daemon = True
        animate_thread.start()

    # Static trait handler: Traits calls it whenever density_filter changes.
    def _density_filter_changed(self):
        """Rebuild the density and stress grids using the new filter value.

        The slider can be moved before the model exists. Until ``_initial``
        has created ``simp_solver``, the change is only printed.
        """
        print('the density is :', self.density_filter)
        if self.simp_solver is None:
            # Previously raised AttributeError on None.resultdata.
            return

        result = self.simp_solver.resultdata

        result.unstrgrid_density = result.generate_unstrgrid_mesh(
            self.density_filter)
        result.update_unstrgrid_density(result.density)
        result.vtkdatasource_density.data = result.unstrgrid_density
        result.vtkdatasource_density.update()

        result.unstrgrid_stress = result.generate_unstrgrid_mesh(
            self.density_filter)
        result.update_unstrgrid_stress(result.stress)
        result.vtkdatasource_stress.data = result.unstrgrid_stress
        result.vtkdatasource_stress.update()


    # Populate the result scenes.
    def _mayavi(self):
        """Add the solver's VTK data sources and Surface modules to the scenes.

        For each of the mesh, displacement, stress, strain and density
        scenes, this makes the scene current on ``self.scene.engine``, adds
        the matching ``self.simp_solver.resultdata.vtkdatasource_*`` source
        and attaches ``Surface`` module(s):

        * mesh: a black wireframe (line width 1.0) named 'mesh_wireframe',
          plus a solid surface named 'mesh_solid';
        * displacement / stress / strain: filled contours with a legend;
        * density: a surface with a legend.

        Modules are configured through the objects just created, not via
        ``current_scene.children[0]...``. The old chains would hit an older
        pipeline if a scene already contained one.
        """
        print('updating mayavi')

        engine = self.scene.engine
        result = self.simp_solver.resultdata

        # Mesh scene: wireframe overlay plus a solid surface.
        wireframe = self._add_surface_to_scene(
            engine, self.scene, result.vtkdatasource_mesh, 'mesh_wireframe')
        wireframe.actor.property.representation = 'wireframe'
        wireframe.actor.property.color = (0, 0, 0)
        wireframe.actor.property.line_width = 1.0
        engine.add_module(Surface(name='mesh_solid'))

        # Displacement, stress and strain scenes: filled contours + legend.
        contour_scenes = (
            (self.scene0, result.vtkdatasource_displacement, 'displacement'),
            (self.scene1, result.vtkdatasource_stress, 'stress'),
            (self.scene2, result.vtkdatasource_strain, 'strain'),
        )
        for scene_model, source, name in contour_scenes:
            surface = self._add_surface_to_scene(engine, scene_model,
                                                 source, name)
            surface.enable_contours = True
            surface.contour.filled_contours = True
            surface.module_manager.scalar_lut_manager.show_legend = True

        # Density scene: plain surface + legend.
        density = self._add_surface_to_scene(
            engine, self.scene3, result.vtkdatasource_density, 'density')
        density.module_manager.scalar_lut_manager.show_legend = True

    @staticmethod
    def _add_surface_to_scene(engine, scene_model, source, name):
        """Make *scene_model* current, add *source* and a new ``Surface``
        named *name* to it, and return that Surface module."""
        engine.current_scene = scene_model.mayavi_scene
        engine.add_source(source)
        surface = Surface(name=name)
        engine.add_module(surface)
        return surface
     
       
    def _update_vtkdatasource(self, old, new):
        """Push the solver's latest result grids into the Mayavi data sources.

        Registered in ``_initial`` as a listener on the solver's ``loop``
        trait. ``old`` and ``new`` are supplied by the Traits notifier and
        are not used here.
        """
        solver = self.simp_solver
        result = solver.resultdata
        loop = solver.loop  # read once; the solver runs on another thread

        # Threshold schedule: 0 at loop 0, 0.85 for loops 1-9, then
        # 1/e**i + 0.5 with self.i incremented on every call from loop 10.
        # NOTE(review): this value is computed but never applied. The density
        # mesh below is always regenerated with filter=1. Either pass it on
        # or drop it (and self.i).
        density_threshold = 0
        print('updating vtkdatasource')
        if 0 < loop < 10:
            density_threshold = 0.85
        elif loop >= 10:
            density_threshold = 1 / np.e ** self.i + 0.5
            self.i += 1

        result.vtkdatasource_displacement.data = result.unstrgrid_displacement
        result.vtkdatasource_displacement.update()

        result.vtkdatasource_stress.data = result.unstrgrid_stress
        result.vtkdatasource_stress.update()

        result.vtkdatasource_strain.data = result.unstrgrid_strain
        result.vtkdatasource_strain.update()

        result.unstrgrid_density = result.generate_unstrgrid_mesh(filter=1)
        result.update_unstrgrid_density(result.density)
        result.vtkdatasource_density.data = result.unstrgrid_density
        result.vtkdatasource_density.update()

        print('updating done')
        print("----------------------")

    # Dynamic listener on the engine's current_selection (hooked up in
    # __init__ via on_trait_change).
    def _selection_change(self, old, new):
        """Re-emit an engine selection change as a change of this object's
        ``current_selection`` property.

        NOTE(review): Traits calls a two-argument ``on_trait_change`` handler
        as ``handler(name, new)``. ``old`` may therefore receive the trait
        name rather than the previous selection; confirm before relying on
        the old value downstream.
        """
        trait_name = 'current_selection'
        self.trait_property_changed(trait_name, old, new)

    def _get_current_selection(self):
        """Getter of the ``current_selection`` property: the object currently
        selected in the Mayavi engine."""
        engine = self.scene.engine
        return engine.current_selection

    def _plot_convergence_curve(self):
        """Live-plot the solver's convergence history with matplotlib.

        Every 0.5 s this redraws ``simp_solver.strain_energy`` (blue, left
        axis) and ``simp_solver.volume_rate`` (green, right axis, y limited
        to [0, 1]) against the iteration index. It stops once
        ``simp_solver.finished`` is true. The final figure is saved to
        ``Convergence_curve.png`` and the window is then kept alive with
        ``plt.pause(36000)``. Exceptions raised while plotting are printed
        and swallowed, since this runs on a background thread.
        """
        plt.close()  # close any previously open figure window

        fig = plt.figure()
        ax = fig.add_subplot(1, 1, 1)
        ax.axis("auto")
        ax.set_ylabel('Strain_energy')
        ax.set_title('convergence curves of strain energy and volume rate')

        ax1 = ax.twinx()
        ax1.set_ylabel('volume_rate')
        ax1.set_ylim([0, 1])
        plt.grid(True)
        plt.ion()  # interactive mode on

        # Create each curve once and update it in place. The previous code
        # called fig.hold(False), removed in matplotlib 3.0, which raised
        # AttributeError before anything was drawn. Without hold it would
        # also have stacked a new Line2D on every poll.
        energy_line, = ax.plot([], [], c='b')
        volume_line, = ax1.plot([], [], c='g')

        def _redraw():
            # Load the latest histories and rescale to fit them.
            energy = self.simp_solver.strain_energy
            volume = self.simp_solver.volume_rate
            energy_line.set_data(range(len(energy)), energy)
            volume_line.set_data(range(len(volume)), volume)
            for axes in (ax, ax1):
                axes.relim()
                axes.autoscale_view()

        try:
            while True:
                ax.set_xlabel('Iteration:' + str(self.simp_solver.loop))
                _redraw()
                plt.pause(0.5)
                if self.simp_solver.finished:
                    break
            _redraw()
            plt.savefig('Convergence_curve.png')
            plt.pause(36000)
        except Exception as err:
            # Best effort: report the problem and let the thread end.
            print(err)



    def _plot(self):
        """Placeholder: currently does nothing."""
    def _animate(self):

        self.scene.engine.current_scene = self.scene4.mayavi_scene
        src = mlab.pipeline.open((self.address+'\density\density_00.vtu'))
        src.play = False
        src.add_module(Surface(name='animate_density'))
        # self.scene.engine.current_scene.children[0].children[0].children[0].enable_contours=True
        # self.scene.engine.current_scene.children[0].children[0].children[0].contour.filled_contours=True
        self.scene.engine.current_scene.children[0].children[0].children[0].module_manager.scalar_lut_manager.show_legend=True