示例#1
0
class Nirvana(HasTraits):
    """Main window: a Matplotlib figure beside a tabbed pair of control panels."""

    # Panels shown as tabs on the right; both are built around self.figure.
    xcontrol = Instance(XControl)
    csetting = Instance(XSetting)
    # Figure object driven by the plot editor on the left.
    figure = Instance(Figure)

    view = View(
        HSplit(
            Item("figure", editor=MPLFigureEditor(), show_label=False, width=0.85),
            Group(
                Item('xcontrol', style='custom', show_label=False),
                Item('csetting', style='custom', show_label=False),
                show_labels=False,  # hide labels of every item in this group
                layout='tabbed',
            ),
            show_labels=False,  # hide labels of every item in the split
        ),
        resizable=True,
        height=0.95,
        width=0.99,
        buttons=[OKButton],
    )

    def _figure_default(self):
        """Create the figure with a single axes covering almost the whole canvas."""
        fig = Figure()
        fig.add_axes([0.05, 0.04, 0.9, 0.92])
        return fig

    def _xcontrol_default(self):
        """Build the control panel bound to this window's figure."""
        panel = XControl(figure=self.figure)
        return panel

    def _csetting_default(self):
        """Build the settings panel and populate its candidate lists."""
        setting = XSetting(figure=self.figure)
        candidate_groups = [base_infos, index_infos, custom_infos]
        setting.generate_candidate(candidate_groups)
        return setting
示例#2
0
class Test(HasTraits):
    """Minimal demo: a Matplotlib figure embedded in a Traits UI window."""

    # Per-instance Figure; the default is built from Figure with no arguments.
    figure = Instance(Figure, ())

    view = View(Item('figure', editor=MPLFigureEditor(), show_label=False),
                width=400,
                height=300,
                resizable=True)

    def __init__(self, **traits):
        """Create the window and draw a closed 11-lobed curve on its figure.

        Keyword arguments are forwarded to ``HasTraits.__init__`` so trait
        values (e.g. ``figure``) can be supplied at construction time like
        any other HasTraits subclass; ``Test()`` behaves exactly as before.
        """
        super(Test, self).__init__(**traits)
        axes = self.figure.add_subplot(111)
        t = linspace(0, 2 * pi, 200)
        # Radius modulation shared by the x and y coordinates.
        r = 1 + 0.5 * cos(11 * t)
        axes.plot(sin(t) * r, cos(t) * r)
示例#3
0
class MainWindow(HasTraits):
    """Top-level window: a Matplotlib figure on the left, a control panel on the right.

    UI events for the view are routed to a MainWindowHandler instance.
    """

    # Figure displayed through MPLFigureEditor and handed to the panel.
    figure = Instance(Figure)

    # Control panel; receives the same figure so it can draw into it.
    panel = Instance(ControlPanel)

    view = View(
        HSplit(
            Item('figure', editor=MPLFigureEditor(), dock='vertical'),
            Item('panel', style="custom"),
            show_labels=False,
        ),
        resizable=True,
        height=0.75,
        width=0.75,
        handler=MainWindowHandler(),
        buttons=NoButtons,
    )

    def _panel_default(self):
        """Create the control panel bound to this window's figure."""
        return ControlPanel(figure=self.figure)

    def _figure_default(self):
        """Create a figure holding one nearly full-size axes."""
        fig = Figure()
        fig.add_axes([0.05, 0.04, 0.9, 0.92])
        return fig
示例#4
0
class Mainwindow(HasTraits):
    """Interactive pull-out simulation window.

    ``draw()`` runs ``time_loop.eval()`` and shows six diagnostic plots
    (bond-slip law, pull-out curve, shear flow, displacement/slip, strain,
    stress).  The material and ``L_x`` sliders re-run the simulation; the
    ``time`` slider replays an already recorded step without re-running.
    """

    mats_eval = Instance(MATSEval)

    fets_eval = Instance(FETS1D52ULRH)

    time_stepper = Instance(TStepper)

    time_loop = Instance(TLoop)

    # Records filled by draw() from time_loop.eval(); indexed by time step.
    t_record = Array
    U_record = Array
    F_record = Array
    sf_record = Array
    eps_record = List
    sig_record = List

    figure = Instance(Figure)

    def _figure_default(self):
        return Figure()

    # Re-run button.  NOTE: the slider listener below used to be named
    # ``plot`` as well; that ``def`` rebound the class attribute and silently
    # discarded this Button trait, so _plot_fired could never run.
    plot = Button()

    def _plot_fired(self):
        """Re-run the simulation and refresh the canvas."""
        self._redraw_all()

    sigma_y = Range(0.2, 1.2)
    E_b = Range(0.05, 0.35)
    K_bar = Range(-0.01, 0.05)

    @on_trait_change('sigma_y, E_b, K_bar')
    def _update_material(self):
        """Copy the material sliders into mats_eval, then re-run and redraw."""
        self.mats_eval.sigma_y = self.sigma_y
        self.mats_eval.E_b = self.E_b
        self.mats_eval.K_bar = self.K_bar
        self._redraw_all()

    L_x = Range(5., 15., value=15.)

    @on_trait_change('L_x')
    def plot1(self):
        """Copy L_x into time_stepper, then re-run and redraw."""
        self.time_stepper.L_x = self.L_x
        self._redraw_all()

    # Lazily created axes of the 2x3 grid:
    #   231 ax1 bond-slip   232 ax2 pull-out curve   233 ax5 strain
    #   234 ax3 shear flow  235 ax4 displacement     236 ax6 stress
    ax1 = Property()

    @cached_property
    def _get_ax1(self):
        return self.figure.add_subplot(231)

    ax2 = Property()

    @cached_property
    def _get_ax2(self):
        return self.figure.add_subplot(232)

    ax3 = Property()

    @cached_property
    def _get_ax3(self):
        return self.figure.add_subplot(234)

    ax4 = Property()

    @cached_property
    def _get_ax4(self):
        return self.figure.add_subplot(235)

    ax5 = Property()

    @cached_property
    def _get_ax5(self):
        return self.figure.add_subplot(233)

    ax6 = Property()

    @cached_property
    def _get_ax6(self):
        return self.figure.add_subplot(236)

    def _redraw_all(self):
        """Re-run the simulation, redraw all axes and refresh the canvas."""
        self.draw()
        self.figure.canvas.draw()

    def draw(self):
        """Run the time loop, store its records and plot the final step.

        The canvas itself is not refreshed here; callers do that.
        """
        (self.U_record, self.F_record, self.sf_record, self.t_record,
         self.eps_record, self.sig_record) = self.time_loop.eval()

        slip, bond = self.time_stepper.mats_eval.get_bond_slip()
        self.ax1.cla()
        self.ax1.plot(slip, bond)
        self.ax1.set_title('bond-slip law')

        self._plot_step(-1)

    def _plot_step(self, idx):
        """Redraw axes 2-6 for recorded step ``idx`` (negative indices allowed)."""
        # DOF whose history forms the pull-out force-displacement curve.
        n_dof = 2 * self.time_stepper.domain.n_active_elems + 1

        self.ax2.cla()
        self.ax2.plot(self.U_record[:, n_dof], self.F_record[:, n_dof])
        # Red marker: the currently displayed step on the full curve.
        self.ax2.plot(self.U_record[idx, n_dof], self.F_record[idx, n_dof], 'ro')
        self.ax2.set_title('pull-out force-displacement curve')

        self.ax3.cla()
        X = np.linspace(0, self.time_stepper.L_x, self.time_stepper.n_e_x + 1)
        # Interior nodes duplicated: two x-values (end nodes) per element.
        X_ip = np.repeat(X, 2)[1:-1]
        self.ax3.plot(X_ip, self.sf_record[idx, :])
        self.ax3.set_title('shear flow in the bond interface')

        self.ax4.cla()
        # Two DOFs per node: row 0 and row 1 are the two displacement fields.
        U = np.reshape(self.U_record[idx, :], (-1, 2)).T
        self.ax4.plot(X, U[0])
        self.ax4.plot(X, U[1])
        self.ax4.plot(X, U[1] - U[0])
        self.ax4.set_title('displacement and slip')

        self.ax5.cla()
        self.ax5.plot(X_ip, self.eps_record[idx][:, :, 0].flatten())
        self.ax5.plot(X_ip, self.eps_record[idx][:, :, 2].flatten())
        self.ax5.set_title('strain')

        self.ax6.cla()
        self.ax6.plot(X_ip, self.sig_record[idx][:, :, 0].flatten())
        self.ax6.plot(X_ip, self.sig_record[idx][:, :, 2].flatten())
        self.ax6.set_title('stress')

    time = Range(0.00, 1.02, value=1.02)

    @on_trait_change('time')
    def draw_t(self):
        """Show the recorded step closest to ``time`` without re-running."""
        if len(self.t_record) == 0:
            # Nothing recorded yet (draw() has not run); argmin would raise.
            return
        idx = np.abs(self.time - self.t_record).argmin()
        self._plot_step(idx)
        self.figure.canvas.draw()

    view = View(
        HSplit(Item('figure',
                    editor=MPLFigureEditor(),
                    dock='vertical',
                    width=0.7,
                    height=0.9),
               Group(Item('mats_eval'), Item('fets_eval'),
                     Item('time_stepper'), Item('time_loop'), Item('sigma_y'),
                     Item('E_b'), Item('K_bar'), Item('L_x'), Item('time')),
               show_labels=False),
        resizable=True,
        height=0.9,
        width=1.0,
    )
示例#5
0
class MainWindow(HasTraits):
    """Fourier-series viewer: a control panel plus a configurable plot."""

    panel = Instance(ControlPanel, ())

    df = DelegatesTo('panel')

    plot_fourier_series = Bool(True)
    plot_data = Bool(True)

    plot_type = Trait('0_plot_xy', {'0_plot_xy':0,
                                  '1_plot_n_coeff':1,
                                  '2_plot_freq_coeff':2,
                                  '3_plot_freq_coeff_abs':3,
                                  '4_plot_freq_energy':4})

    # Traits tagged changed=True trigger an automatic redraw (see _redraw).
    plot_title = Str(enter_set=True, auto_set=False, changed=True)
    label_fsize = Float(15, enter_set=True, auto_set=False, changed=True)
    tick_fsize = Float(15, enter_set=True, auto_set=False, changed=True)
    title_fsize = Float(15, enter_set=True, auto_set=False, changed=True)

    label_default = Bool(True)
    x_label = Str('x', changed=True)
    x_limit_on = Bool(False)
    x_limit = Tuple((0., 1.), changed=True)
    y_label = Str('y', changed=True)
    y_limit_on = Bool(False)
    y_limit = Tuple((0., 1.), changed=True)

    figure = Instance(Figure)

    def _figure_default(self):
        figure = Figure(tight_layout=True)
        figure.add_subplot(111)
        return figure

    @on_trait_change('+changed')
    def _redraw(self):
        self._draw_fired()

    draw = Button

    def _draw_fired(self):
        """Clear the axes and redraw them for the selected ``plot_type``.

        ``plot_type_`` values: 0 data / Fourier series vs x, 1 cos/sin
        coefficients vs n, 2 coefficients vs frequency, 3 absolute
        coefficients vs frequency, 4 energy vs frequency.
        """
        plot_type = self.plot_type_
        default_labels = {
            0: ('x', 'y'),
            1: ('n', 'coeff'),
            2: ('freq', 'coeff'),
            3: ('freq', 'coeff'),
            4: ('freq', 'energy'),
        }
        if self.label_default and plot_type in default_labels:
            # Apply default labels *before* touching the axes: x_label/y_label
            # carry changed=True, so assigning them re-enters this method via
            # _redraw.  Doing it first lets that nested redraw finish before
            # this call clears and draws the axes.
            self.x_label, self.y_label = default_labels[plot_type]

        axes = self.figure.axes[0]
        axes.clear()
        df = self.df

        if plot_type == 0:
            if self.plot_data and not self.plot_fourier_series:
                axes.plot(df.x, df.y, color='blue', label='data')
            elif self.plot_fourier_series and not self.plot_data:
                axes.plot(df.x, df.y_fourier, color='green', label='fourier')
            else:
                # Both selected (or neither): show both curves.
                axes.plot(df.x, df.y, color='blue', label='data')
                axes.plot(df.x, df.y_fourier, color='green', label='fourier')
            axes.grid()
        elif plot_type == 1:
            # Offset the stems by +-0.05 so cos and sin bars do not overlap.
            axes.vlines(df.N_arr - 0.05, [0], df.cos_coeff, color='blue', label='cos')
            axes.vlines(df.N_arr + 0.05, [0], df.sin_coeff, color='green', label='sin')
            axes.grid()
        elif plot_type == 2:
            axes.vlines(df.freq, [0], df.cos_coeff, color='blue', label='cos')
            axes.vlines(df.freq, [0], df.sin_coeff, color='green', label='sin')
            axes.grid()
        elif plot_type == 3:
            axes.vlines(df.freq, [0], np.abs(df.cos_coeff), color='blue', label='cos')
            axes.vlines(df.freq, [0], np.abs(df.sin_coeff), color='green', label='sin')
            y_val = np.abs(np.hstack((df.cos_coeff, df.sin_coeff))).max()
            axes.set_ybound((0, y_val * 1.05))
        elif plot_type == 4:
            axes.plot(df.freq, df.energy, 'k-', label='energ')
            # Bound by the plotted energy (the cos/sin coefficients were used
            # here before, which could clip or over-stretch the curve).
            y_val = np.abs(df.energy).max()
            axes.set_ybound((0, y_val * 1.05))

        axes.legend(loc='best')
        self._finish_axes(axes)
        self.figure.canvas.draw()

    def _finish_axes(self, axes):
        """Apply title, axis labels, optional limits and tick font sizes."""
        label_fsize = self.label_fsize
        tick_fsize = self.tick_fsize
        axes.set_title(self.plot_title, fontsize=self.title_fsize)
        axes.set_xlabel(self.x_label, fontsize=label_fsize)
        axes.set_ylabel(self.y_label, fontsize=label_fsize)
        if self.x_limit_on:
            axes.set_xlim(self.x_limit)
        if self.y_limit_on:
            axes.set_ylim(self.y_limit)
        # position=(0, -.01) shifts the x tick labels slightly away from the x axis.
        p.setp(axes.get_xticklabels(), fontsize=tick_fsize, position=(0, -.01))
        p.setp(axes.get_yticklabels(), fontsize=tick_fsize)

    traits_view = View(
        HSplit(
            Tabbed(
                Group(
                    Item('panel@', show_label=False, id='fourier.panel'),
                    Group(
                        Item('plot_data'),
                        Item('plot_fourier_series'),
                        '_',
                        Item('plot_type'),
                        label='plot options',
                        show_border=True,
                        id='fourier.plot_options'
                    ),
                    UItem('draw', label='calculate and draw'),
                    label='fourier',
                    id='fourier.fourier'
                ),
                Group(
                    Item('plot_title', label='title'),
                    Item('title_fsize', label='title fontsize'),
                    Item('label_fsize', label='label fontsize'),
                    Item('tick_fsize', label='tick fontsize'),
                    Item('_'),
                    Item('x_limit_on'),
                    Item('x_limit', label='x limit - plot', enabled_when='x_limit_on'),
                    Item('_'),
                    Item('y_limit_on'),
                    Item('y_limit', label='y limit - plot', enabled_when='y_limit_on'),
                    Item('_'),
                    Item('label_default'),
                    Item('x_label', enabled_when='label_default==False'),
                    Item('y_label', enabled_when='label_default==False'),
                    label='plot settings',
                    show_border=True,
                    id='fourier.plot_settings',
                    dock='tab',
                ),
            ),
            VGroup(
                Item('figure', editor=MPLFigureEditor(),
                     resizable=True, show_label=False),
                label='Plot sheet',
                id='fourier.figure',
                dock='tab',
            ),
        ),
        title='Fourier series',
        id='main_window.view',
        resizable=True,
        width=0.7,
        height=0.7,
        buttons=[OKButton]
    )
示例#6
0
class Graph(HasTraits):
    """
    Plot component: data-selection controls on the left, the plot on the right.
    """
    name = Str  # plot name, shown in the tab title and as the plot title
    data_source = Instance(DataSource)  # data source holding the columns
    figure = Instance(Figure)  # Figure object driven by the plot editor
    selected_xaxis = Str  # name of the column used for the X axis
    selected_items = List  # names of the columns plotted on the Y axis

    clear_button = Button(u"清除")  # quickly clears every Y-axis selection

    view = View(
        # HSplit: left and right panes with a draggable divider between them.
        HSplit(
            # Left pane: a group of selection controls.
            VGroup(
                Item("name"),  # plot-name text field
                Item("clear_button"),  # clear button
                Heading(u"X轴数据"),  # static heading text
                # X-axis selector: an EnumEditor (combo box) whose choices
                # come from data_source.names.
                Item("selected_xaxis", editor=
                    EnumEditor(name="object.data_source.names", format_str=u"%s")),
                Heading(u"Y轴数据"),  # static heading text
                # Y-axis selector: several columns may be chosen, so use a
                # two-column check-box list.
                Item("selected_items", style="custom",
                     editor=CheckListEditor(name="object.data_source.names",
                            cols=2, format_str=u"%s")),
                show_border=True,  # draw a border around the group
                scrollable=True,  # scroll when there are too many controls
                show_labels=False  # hide labels of every item in the group
            ),
            # Right pane: the plot.
            Item("figure", editor=MPLFigureEditor(), show_label=False, width=600)
        )
    )

    def _name_changed(self):
        """
        Update the plot title when the plot name changes.
        """
        axe = self.figure.axes[0]
        axe.set_title(self.name)
        self.figure.canvas.draw()

    def _clear_button_fired(self):
        """
        Clear-button handler: deselect every Y series and redraw.
        """
        self.selected_items = []
        self.update()

    def _figure_default(self):
        """
        Default figure: one axes with a margin on every side.
        """
        figure = Figure()
        figure.add_axes([0.1, 0.1, 0.85, 0.80])
        return figure

    def _selected_items_changed(self):
        """
        Redraw when the Y-axis selection changes.
        """
        self.update()

    def _selected_xaxis_changed(self):
        """
        Redraw when the X-axis selection changes.
        """
        self.update()

    def update(self):
        """
        Redraw every selected Y series against the selected X column.

        If the X data cannot be looked up (no data source yet, or no valid
        column selected), the axes are left empty.
        """
        axe = self.figure.axes[0]
        axe.clear()
        try:
            xdata = self.data_source.data[self.selected_xaxis]
        except Exception:
            # Was a bare ``except:`` (which also swallowed KeyboardInterrupt).
            # Refresh the canvas so the curves just cleared do not linger.
            if self.figure.canvas is not None:
                self.figure.canvas.draw()
            return
        for field in self.selected_items:
            axe.plot(xdata, self.data_source.data[field], label=field)
        axe.set_xlabel(self.selected_xaxis)
        axe.set_title(self.name)
        if self.selected_items:
            # legend() with no labelled artists only emits a warning.
            axe.legend()
        self.figure.canvas.draw()
示例#7
0
		ax = self.figure.add_subplot(111)
		ax.imshow(self.imgarray, cmap = cm.gist_gray)
		self.figure.canvas.draw()

	def do_threshold(self):
		"""Binarise the image: pixels above the Otsu threshold become True."""
		cutoff = threshold_otsu(self.imgarray)
		mask = self.imgarray > cutoff
		self.imgarray = mask

	def do_swirl(self):
		"""Apply a swirl distortion (strength 10, radius 150, interpolation order 2)."""
		self.imgarray = swirl(
			self.imgarray, rotation=0, strength=10, radius=150, order=2)

	def do_rotate(self):
		"""Rotate the image by 90 degrees with scipy.ndimage.rotate.

		Author's note: rotating after thresholding makes the image look odd.
		"""
		self.imgarray = ndimage.rotate(self.imgarray, 90)

	def do_undo(self):
		"""Single-step undo: restore ``imgarray`` from ``prev_array``."""
		previous = self.prev_array
		self.imgarray = previous

# Configure the GUI: query/refresh/URL controls, the image figure, and the
# image-processing buttons, in that order.
view1 = View(
    Item('query', width=-250, resizable=True, label="Query String", full_size=True),
    Item('refresh', resizable=True, label="Refresh Image"),
    Item('url', width=-250, resizable=True, padding=2, label="Image URL",
         full_size=True),
    Item('figure', show_label=False, width=600, height=600, resizable=True,
         editor=MPLFigureEditor()),
    Item('thresh', label='Threshold', show_label=False),
    Item('swirled', resizable=False, label='Swirl', show_label=False),
    Item('rotate', resizable=False, label="Rotate 90", show_label=False),
    Item('undo', resizable=False, label="UNDO", show_label=False),
)

# Create the image manipulator and open it with the view above.
i = ImgManip()
i.configure_traits(view=view1)
示例#8
0
class MPLPlot(Viewer):
    """
    A plot that displays itself on a Matplotlib figure and updates itself
    dynamically from a Variables instance (which must be passed in on
    initialisation).  The plotted functions are computed from ``expr``,
    which should also be set on init: a comma-separated list of Python
    expressions over the variables in the pool, one plot line per expression.
    """
    name = Str('MPL Plot')
    figure = Instance(Figure, ())
    expr = Str

    # Axis bounds; while the matching *_auto flag is set, update() recomputes them.
    x_max = Float
    x_max_auto = Bool(True)
    x_min = Float
    x_min_auto = Bool(True)
    y_max = Float
    y_max_auto = Bool(True)
    y_min = Float
    y_min_auto = Bool(True)

    scroll = Bool(True)
    scroll_width = Float(300)

    legend = Bool(False)
    legend_pos = Enum('upper left', 'upper right', 'lower left', 'lower right',
                      'right', 'center left', 'center right', 'lower center',
                      'upper center', 'center', 'best')

    traits_view = View(Item(name='name', label='Plot name'),
                       Item(name='expr',
                            label='Expression(s)',
                            editor=TextEditor(enter_set=True, auto_set=False)),
                       Item(label='Use commas\nfor multi-line plots.'),
                       HGroup(
                           Item(name='legend', label='Show legend'),
                           Item(name='legend_pos', show_label=False),
                       ),
                       VGroup(HGroup(Item(name='x_max', label='Max'),
                                     Item(name='x_max_auto', label='Auto')),
                              HGroup(Item(name='x_min', label='Min'),
                                     Item(name='x_min_auto', label='Auto')),
                              HGroup(
                                  Item(name='scroll', label='Scroll'),
                                  Item(name='scroll_width',
                                       label='Scroll width'),
                              ),
                              label='X',
                              show_border=True),
                       VGroup(HGroup(Item(name='y_max', label='Max'),
                                     Item(name='y_max_auto', label='Auto')),
                              HGroup(Item(name='y_min', label='Min'),
                                     Item(name='y_min_auto', label='Auto')),
                              label='Y',
                              show_border=True),
                       title='Plot settings',
                       resizable=True)

    view = View(Item(name='figure', editor=MPLFigureEditor(),
                     show_label=False),
                width=400,
                height=300,
                resizable=True)

    legend_prop = matplotlib.font_manager.FontProperties(size=8)

    def start(self):
        """Create the axes with one placeholder line; update() fills it in."""
        axes = self.figure.add_subplot(111)
        axes.plot([0], [0])

    def update(self):
        """
        Update the plot lines from the Variables instance, recompute the
        automatic axis bounds and schedule a redraw of the figure.

        Does nothing until start() has created the first line.
        """
        axes = self.figure.gca()
        lines = axes.get_lines()
        if not lines:
            return

        exprs = self.get_exprs()
        # Keep one line per expression: add lines for new expressions.
        for _ in range(len(exprs) - len(lines)):
            axes.plot([0], [0])
        lines = axes.get_lines()

        # Sample window to fetch; it is the same for every expression.
        first = 0
        last = None
        if self.scroll and self.x_min_auto and self.x_max_auto:
            # Negative start, presumably counted back from the newest sample.
            first = -self.scroll_width
        if not self.x_min_auto:
            first = int(self.x_min)
        if not self.x_max_auto:
            last = int(self.x_max) + 1

        # Extremes start at 0, so automatic bounds always include the origin.
        max_xs = max_ys = min_xs = min_ys = 0

        for n, expr in enumerate(exprs):
            ys = self.variables.new_expression(expr).get_array(first, last)

            if len(ys) != 0:
                xs = self.variables.new_expression('sample_num').get_array(
                    first, last)
            else:
                xs = [0]
                ys = [0]

            if len(xs) != len(ys):
                print("MPL Plot: x and y arrays different sizes!!! Ignoring (but fix me soon).")
                return

            lines[n].set_xdata(xs)
            lines[n].set_ydata(ys)

            max_xs = max(max_xs, max(xs))
            max_ys = max(max_ys, max(ys))
            min_xs = min(min_xs, min(xs))
            min_ys = min(min_ys, min(ys))

        if self.x_max_auto:
            self.x_max = max_xs
        if self.x_min_auto:
            if self.scroll and self.x_max_auto:
                # Show only the last scroll_width samples, never below 0.
                self.x_min = max(self.x_max - self.scroll_width, 0)
            else:
                self.x_min = min_xs
        if self.y_max_auto:
            self.y_max = max_ys
        if self.y_min_auto:
            self.y_min = min_ys

        axes.set_xbound(upper=self.x_max, lower=self.x_min)
        # 10% head-room on the y axis.
        axes.set_ybound(upper=self.y_max * 1.1, lower=self.y_min * 1.1)

        self.draw_plot()

    def get_exprs(self):
        """Return the individual expressions of the comma-separated ``expr``."""
        return self.expr.split(',')

    def add_expr(self, expr):
        """Append ``expr`` to ``self.expr``, inserting a comma separator if needed.

        Previously this tested ``self.expr[:-1] == ','`` (everything but the
        last character) instead of checking for a trailing comma.
        """
        if self.expr == '' or self.expr.endswith(','):
            self.expr += expr
        else:
            self.expr += ',' + expr

    def draw_plot(self):
        """Schedule a canvas redraw on the GUI thread, if a canvas exists."""
        if self.figure.canvas:
            CallAfter(self.figure.canvas.draw)

    @on_trait_change('legend_pos')
    def update_legend_pos(self, old, new):
        """ Move the legend, calls update_legend """
        self.update_legend(None, None)

    @on_trait_change('legend')
    def update_legend(self, old, new):
        """ Called when we change the legend display """
        axes = self.figure.gca()
        lines = axes.get_lines()
        exprs = self.get_exprs()

        if len(exprs) >= 1 and self.legend:
            axes.legend(lines[:len(exprs)],
                        exprs,
                        loc=self.legend_pos,
                        prop=self.legend_prop)
        else:
            axes.legend_ = None

        self.draw_plot()
示例#9
0
class MainWindow(HasTraits):
    """Sequence viewer: edit sequences and plot their selected analog and
    digital (TTL) channels on two stacked axes."""

    def _pck_(self, action, f=mainpck):
        """Load (``action='load'``) or save (``action='save'``) ``seqs`` via pickle.

        Load failures are reported and ignored.  Any other ``action`` is a
        no-op (it used to raise NameError on the unbound file handle).

        NOTE(review): pickle.load can execute arbitrary code from the file;
        only load panel files you trust.
        """
        if action == 'load':
            try:
                with open(f, "rb") as fpck:
                    print('Loading panel from %s' % f)
                    self.seqs = pickle.load(fpck)
            except Exception:
                print("Loading Fail")
                return
        elif action == 'save':
            print('Saving panel to %s' % f)
            with open(f, "w+b") as fpck:
                pickle.dump(self.seqs, fpck)

    # Declared as a trait so each window gets its own Figure from
    # _figure_default; a bare ``Figure()`` class attribute is a single object
    # shared by every MainWindow instance.
    figure = Instance(Figure)

    seqs = List([sequence(name='S0', file=default_file)])

    add = Button("Add Sequence")

    plot = Button("Plot")

    # Plot buffers rebuilt by get_data().
    data_digi = List()

    data_analog = List()

    data_digi_time = List()

    data_digi_names = List()

    data_analog_time = List()

    data_analog_names = List()

    autorangeY = Bool(True, label="Autorange in Y?")
    plot_rangeY_max = Float(10)
    plot_rangeY_min = Float(0)

    autorange = Bool(True, label="Autorange?")
    plot_range_max = Float(10000)
    plot_range_min = Float()

    def _figure_default(self):
        """Create the figure (image_show replaces its axes when plotting)."""
        figure = Figure()
        figure.add_axes([0.05, 0.04, 0.9, 0.92])
        return figure

    control_group = Group(
        Item('plot', show_label=False),
        VGroup(
            Item('seqs',
                 style='custom',
                 editor=ListEditor(use_notebook=True,
                                   deletable=True,
                                   dock_style='tab',
                                   page_name='.name')),
            Item('add', show_label=False),
            show_labels=False,
            show_border=True,
            label='seqs',
        ))

    view = View(
        HSplit(
            control_group,
            VGroup(
                Item('figure',
                     editor=MPLFigureEditor(),
                     dock='vertical',
                     show_label=False,
                     width=700),
                HGroup('autorange', 'plot_range_min', spring,
                       'plot_range_max'),
                HGroup('autorangeY', 'plot_rangeY_min', spring,
                       'plot_rangeY_max'))),
        width=1,
        height=0.95,
        resizable=True,
        handler=MainWindowHandler(),
    )

    def _add_fired(self):
        """Append a new sequence named after its position."""
        self.seqs.append(
            sequence(name='S%d' % len(self.seqs), file=default_file))

    def _plot_fired(self):
        """Collect the selected channels and plot them."""
        self.get_data()
        self.image_show()

    def get_data(self):
        """Collect the selected channels of every sequence into the plot buffers.

        Digital channels are stacked into lanes below y = 0 (10 units in
        total, one lane per channel); each trace is scaled to 80% of its lane
        with a 10% margin, assuming 0/1 sample values.  Analog channels are
        stored unchanged.
        """
        self.data_digi = []
        self.data_analog = []
        self.data_digi_time = []
        self.data_digi_names = []
        self.data_analog_time = []
        self.data_analog_names = []

        # (sequence, waveform, channel index, channel) for each selected channel.
        selected = [(seq, waveform, i, channel)
                    for seq in self.seqs
                    for waveform in seq.waveforms
                    for i, channel in enumerate(waveform.channels)
                    if channel in waveform.select_channels]

        n_digi = sum(1 for _seq, wf, _i, _ch in selected if wf.name == 'Digital')
        height = 10. / n_digi if n_digi else 0.

        lane = 0
        for seq, waveform, i, channel in selected:
            name = seq.name + '_' + waveform.name + '_' + channel
            if waveform.name == 'Digital':
                lane += 1  # lanes are numbered from 1, top to bottom
                self.data_digi_time.append(waveform.time)
                self.data_digi_names.append(name)
                # Use ``v`` for samples: reusing ``i`` here used to shadow the
                # channel index (it leaks out of list comprehensions in Py2).
                self.data_digi.append([
                    v * height * 0.8 - lane * height + height * 0.1
                    for v in waveform.data[i]
                ])
            else:
                self.data_analog_time.append(waveform.time)
                self.data_analog_names.append(name)
                self.data_analog.append(waveform.data[i])

    def image_clear(self):
        """ Clears canvas 
        """
        self.figure.clf()
        wx.CallAfter(self.figure.canvas.draw)

    def image_show(self):
        """Plot the collected data: analog channels on the upper axes, digital
        lanes on a twin of the lower axes, then apply the axis ranges."""
        self.image_clear()

        analog_axis = self.figure.add_axes([0.08, 0.5, 0.7, 0.4])
        ttl_axis = self.figure.add_axes([0.08, 0.05, 0.7, 0.4])
        ttl_axis.set_yticks([])
        digi_axis = ttl_axis.twinx()

        digi_ticks = []
        n_digi = len(self.data_digi_names)
        for i, (t, data, name) in enumerate(zip(self.data_digi_time,
                                                self.data_digi,
                                                self.data_digi_names)):
            digi_axis.step(t, data, where='post', label=name)
            # Tick at the centre of lane i, separator at its bottom edge.
            digi_ticks.append(-(i + 0.5) * 10. / n_digi)
            digi_axis.axhline(-(i + 1) * 10. / n_digi, color='grey', lw=1.5)

        for t, data, name in zip(self.data_analog_time, self.data_analog,
                                 self.data_analog_names):
            analog_axis.step(t, data, where='post', label=name)

        analog_axis.axhline(0, color='black', lw=2)
        analog_axis.legend(bbox_to_anchor=(1.01, 1.01),
                           loc=2,
                           prop={'size': 10})

        analog_axis.set_xlabel('Time(ms)')
        ttl_axis.set_ylabel('TTL')
        analog_axis.set_ylabel('Voltage(V)')
        analog_axis.grid(True)

        if not self.autorangeY:
            analog_axis.set_ylim(self.plot_rangeY_min, self.plot_rangeY_max)
        else:
            axismin = min(analog_axis.get_ylim())
            axismax = max(analog_axis.get_ylim())
            analog_axis.set_ylim(axismin, axismax)
            # Report the automatic range back to the range fields.
            self.plot_rangeY_max = axismax
            self.plot_rangeY_min = axismin

        if not self.autorange:
            analog_axis.set_xlim(self.plot_range_min, self.plot_range_max)
            digi_axis.set_xlim(self.plot_range_min, self.plot_range_max)
        else:
            # Common x range covering both the analog and the digital data.
            axismin = min(analog_axis.get_xlim() + digi_axis.get_xlim())
            axismax = max(analog_axis.get_xlim() + digi_axis.get_xlim())
            analog_axis.set_xlim(axismin, axismax)
            digi_axis.set_xlim(axismin, axismax)
            self.plot_range_max = axismax
            self.plot_range_min = axismin

        digi_axis.set_ylim(bottom=-11, top=0)
        digi_axis.set_yticks(digi_ticks)
        digi_axis.set_yticklabels(self.data_digi_names)

        wx.CallAfter(self.figure.canvas.draw)
class MainWindow(HasTraits):
    """Main window: edit pulse sequences and plot their waveforms.

    The left pane is a notebook of ``sequence`` objects with "Plot" and
    "Add Sequence" buttons; the right pane shows a matplotlib figure with
    analog/physical channels on top and digital (TTL) channels below, plus
    manual/automatic axis-range controls.
    """

    def _pck_(self, action, f=mainpck):
        """Load or save the panel state using the pickle file ``f``.

        :param action: ``'load'`` or ``'save'``; any other value is ignored.
        :param f: path of the pickle file (defaults to ``mainpck``).

        The file holds, in order: ``seqs``, ``autorangeY``,
        ``plot_rangeY_max``, ``plot_rangeY_min``, ``autorange``,
        ``plot_range_max`` and ``plot_range_min``.  Older files that only
        contain ``seqs`` are still accepted; the range settings then keep
        their current values.  Load failures are reported on stdout and
        otherwise ignored.
        """
        if action == 'load':
            # NOTE: unpickling can execute arbitrary code -- only load panel
            # files from a trusted location.
            try:
                with open(f, "rb") as fpck:
                    print('Loading panel from %s ... ' % f)
                    self.seqs = pickle.load(fpck)
                    try:
                        self.autorangeY = pickle.load(fpck)
                        self.plot_rangeY_max = pickle.load(fpck)
                        self.plot_rangeY_min = pickle.load(fpck)

                        self.autorange = pickle.load(fpck)
                        self.plot_range_max = pickle.load(fpck)
                        self.plot_range_min = pickle.load(fpck)
                    except Exception:
                        # Older file format: only the sequences were saved.
                        pass
            except Exception:
                print("Loading Fail")
        elif action == 'save':
            print('Saving panel to %s ...' % f)
            with open(f, "w+b") as fpck:
                for value in (self.seqs,
                              self.autorangeY,
                              self.plot_rangeY_max,
                              self.plot_rangeY_min,
                              self.autorange,
                              self.plot_range_max,
                              self.plot_range_min):
                    pickle.dump(value, fpck)

    # Figure shown in the right-hand pane; one per window, created on first
    # access by _figure_default (a plain ``Figure()`` class attribute would
    # be shared by every instance and would bypass _figure_default).
    figure = Instance(Figure)

    # Creating the class creates one default sequence and advances the global
    # sequence counter so that later sequences get unique names.
    global seqct
    seqct = seqct + 1  #increment seq counter
    seqs = List([sequence(name='S%d' % seqct, txtfile=default_file)])

    selectedseq = Instance(sequence)
    # Notebook index of the selected sequence (the one "Add" copies).
    index = Int

    add = Button("Add Sequence")
    plot = Button("Plot")

    # Parallel lists filled by get_data() and drawn by image_show().
    data_digi = List()
    data_digi_time = List()
    data_digi_names = List()

    data_analog = List()
    data_analog_time = List()
    data_analog_names = List()

    data_physical = List()
    data_physical_time = List()
    data_physical_names = List()

    autorangeY = Bool(True, label="Autorange in Y?")
    plot_rangeY_max = Float(10)
    plot_rangeY_min = Float(0)

    autorange = Bool(True, label="Autorange?")
    plot_range_max = Float(10000)
    plot_range_min = Float()

    def _figure_default(self):
        """Create the window's figure with one axes filling most of it."""
        figure = Figure()
        figure.add_axes([0.05, 0.04, 0.9, 0.92])
        return figure

    #This group contains the View of the control buttons
    #and the waveforms
    control_group = Group(
        Item('plot', show_label=False),
        VGroup(
            Item('seqs',
                 style='custom',
                 editor=ListEditor(use_notebook=True,
                                   selected='selectedseq',
                                   deletable=True,
                                   dock_style='tab',
                                   page_name='.name')),
            Item('add', show_label=False),
            show_labels=False,
            show_border=True,
            label='seqs',
        ))

    #This is the view of the Main Window, including the
    #control group
    view = View(
        HSplit(
            control_group,
            VGroup(
                Item('figure',
                     editor=MPLFigureEditor(),
                     dock='vertical',
                     show_label=False,
                     width=700),
                HGroup('autorange', 'plot_range_min', spring,
                       'plot_range_max'),
                HGroup('autorangeY', 'plot_rangeY_min', spring,
                       'plot_rangeY_max'))),
        title='Display Sequence',
        width=1,
        height=0.95,
        resizable=True,
        handler=MainWindowHandler(),
    )

    def _selectedseq_changed(self, selectedseq):
        """Track the notebook index of the currently selected sequence."""
        # The selection can become None (or a just-deleted sequence) when a
        # tab is closed; keep the previous index instead of raising.
        if selectedseq in self.seqs:
            self.index = self.seqs.index(selectedseq)

    def _add_fired(self):
        """Append a copy of the selected sequence under a new unique name.

        If every sequence has been deleted, a fresh default sequence is
        added instead.
        """
        global seqct
        seqct = seqct + 1

        if self.seqs:
            # Clamp: the stored index may be stale after tabs were deleted.
            source = self.seqs[min(self.index, len(self.seqs) - 1)]
            new = copy.deepcopy(source)
            new.name = 'S%d' % seqct
        else:
            new = sequence(name='S%d' % seqct, txtfile=default_file)
        self.seqs.append(new)

    def _plot_fired(self):
        """Collect the selected channels and redraw the figure."""
        self.get_data()
        self.image_show()

    def get_data(self):
        """Collect the selected channels of every sequence marked ``plotme``.

        Fills three parallel list groups used by image_show():
        ``data_digi*`` (digital channels, each rescaled into its own
        horizontal band below 0), ``data_analog*`` and ``data_physical*``
        (quantities from ``physics.calc``, cached per sequence in
        ``seq.calcwfms`` unless ``seq.recalculate`` is set).
        """
        self.data_digi = []
        self.data_digi_time = []
        self.data_digi_names = []

        self.data_analog = []
        self.data_analog_time = []
        self.data_analog_names = []

        self.data_physical = []
        self.data_physical_time = []
        self.data_physical_names = []

        def selected(waveform):
            """Yield (index, channel) for each selected channel of waveform."""
            for idx, channel in enumerate(waveform.channels):
                if channel in waveform.select_channels:
                    yield idx, channel

        plotted = [seq for seq in self.seqs if seq.plotme]

        # Total number of digital channels: each gets an equal share of the
        # 10-unit-high digital plot.
        n_digi = sum(1 for seq in plotted
                     for waveform in seq.waveforms
                     if waveform.name == 'Digital'
                     for _ in selected(waveform))
        height = 10. / n_digi if n_digi else 0.

        digi_index = 1
        for seq in plotted:
            #Prepare the physical quantities calculator class
            physical = physics.calc(seq.seq.wfms)
            print("\n--------  GETTING DATA TO PRODUCE PLOT  --------")
            for waveform in seq.waveforms:
                for i, channel in selected(waveform):
                    label = seq.name + '_' + waveform.name + '_' + channel

                    if waveform.name == 'Digital':
                        self.data_digi_time.append(waveform.time)
                        self.data_digi_names.append(label)
                        # Channel k is drawn in the band [-k*h, -(k-1)*h];
                        # assuming 0/1 samples, low sits at 10 % and high at
                        # 90 % of the band.
                        self.data_digi.append([
                            level * height * 0.8 - digi_index * height +
                            height * 0.1 for level in waveform.data[i]
                        ])
                        digi_index += 1

                    elif waveform.name == 'Analog':
                        self.data_analog_names.append(label)
                        self.data_analog_time.append(waveform.time[i])
                        self.data_analog.append(waveform.data[i])

                    elif waveform.name == 'Physical':
                        self.data_physical_names.append(label)
                        if channel in seq.calcwfms and not seq.recalculate:
                            print("\n...Reusing Physical: %s" % channel)
                            dat = seq.calcwfms[channel]
                        else:
                            print("\n...Calculating Physical: %s" % channel)
                            dat = physical.calculate(channel)
                            seq.calcwfms[channel] = dat

                        self.data_physical_time.append(dat[0])
                        self.data_physical.append(dat[1])
            # Only clear the flag once this sequence's data was (re)computed;
            # clearing it for unplotted sequences would leave stale caches.
            seq.recalculate = False

    def image_clear(self):
        """Clear the figure and schedule a canvas redraw."""
        self.figure.clf()
        wx.CallAfter(self.figure.canvas.draw)

    def image_show(self):
        """Plot the data collected by get_data() on the canvas.

        The top axes shows analog and physical channels; the bottom axes
        shows each digital channel in its own band, named by right-hand tick
        labels.  Axis ranges follow ``autorange``/``autorangeY``; automatic
        ranges are written back to the range traits so the GUI shows them.
        """
        self.image_clear()
        analog_axis = self.figure.add_axes([0.08, 0.5, 0.7, 0.4])

        digi_axis_left = self.figure.add_axes([0.08, 0.05, 0.7, 0.4])
        digi_axis_left.set_yticks([])
        # Twin axis: the channel names become right-hand tick labels while
        # the left axis carries the 'TTL' label.
        digi_axis = digi_axis_left.twinx()

        # Digital plot: grey lines separate the per-channel bands and each
        # tick sits in the middle of its band.
        n_digi = len(self.data_digi_names)
        digi_ticks = []
        for i, name in enumerate(self.data_digi_names):
            digi_axis.step(self.data_digi_time[i],
                           self.data_digi[i],
                           lw=2.0,
                           where='post',
                           label=name)
            digi_ticks.append(-(i + 0.5) * 10. / n_digi)
            digi_axis.axhline(-(i + 1) * 10. / n_digi, color='grey', lw=1.5)

        digi_axis_left.set_ylabel('TTL')

        digi_axis.set_ylim(bottom=-11, top=0)
        digi_axis.set_yticks(digi_ticks)
        digi_axis.set_yticklabels(self.data_digi_names)

        digi_axis_left.get_xaxis().set_minor_locator(
            matplotlib.ticker.AutoMinorLocator())
        digi_axis_left.grid(True, which='both')

        # Analog channels and physical quantities share the top axes.
        for i, name in enumerate(self.data_analog_names):
            analog_axis.step(self.data_analog_time[i],
                             self.data_analog[i],
                             where='post',
                             label=name)

        for i, name in enumerate(self.data_physical_names):
            analog_axis.step(self.data_physical_time[i],
                             self.data_physical[i],
                             ls='-',
                             lw=1.75,
                             where='post',
                             label=name)

        analog_axis.axhline(0, color='black', lw=2)
        if self.data_analog_names or self.data_physical_names:
            # Skip the legend (and matplotlib's "no handles" warning) when
            # there are no labelled curves.
            analog_axis.legend(bbox_to_anchor=(1.01, 1.01),
                               loc=2,
                               prop={'size': 10})
        analog_axis.set_xlabel('Time(ms)')
        analog_axis.set_ylabel('Voltage(V) / Physical(?)')
        analog_axis.get_xaxis().set_minor_locator(
            matplotlib.ticker.AutoMinorLocator())
        analog_axis.grid(True, which='both')

        # Y range of the analog plot.
        if not self.autorangeY:
            analog_axis.set_ylim(self.plot_rangeY_min, self.plot_rangeY_max)
        else:
            axismin = min(analog_axis.get_ylim())
            axismax = max(analog_axis.get_ylim())
            analog_axis.set_ylim(axismin, axismax)
            self.plot_rangeY_max = axismax
            self.plot_rangeY_min = axismin

        # X (time) range shared by both plots.
        if not self.autorange:
            analog_axis.set_xlim(self.plot_range_min, self.plot_range_max)
            digi_axis.set_xlim(self.plot_range_min, self.plot_range_max)
        else:
            axismin = min(analog_axis.get_xlim() + digi_axis.get_xlim())
            axismax = max(analog_axis.get_xlim() + digi_axis.get_xlim())
            analog_axis.set_xlim(axismin, axismax)
            digi_axis.set_xlim(axismin, axismax)
            self.plot_range_max = axismax
            self.plot_range_min = axismin

        wx.CallAfter(self.figure.canvas.draw)
示例#11
0
class MplPlot(BasePlot, HasTraitsGroup):
    """3-D matplotlib "waterfall" plot of a stack of 1-D datasets.

    Each dataset is drawn as one line at its own position along the y axis.
    Changing the view angles, render quality, labels, limits or order
    triggers a redraw of the figure.
    """
    figure = Instance(Figure, ())
    _draw_pending = Bool(False)

    scale = Enum('linear', 'log', 'sqrt')('linear')
    # Same values as the ``scale`` Enum, as a plain list for iteration.
    scale_values = ['linear', 'log', 'sqrt']
    azimuth = Range(-90, 90, -70)
    elevation = Range(0, 90, 30)
    quality = Range(1, MAX_QUALITY, 1)
    flip_order = Bool(False)
    x_lower = Float(0.0)
    x_upper = Float
    x_label = Str(r'Angle (2$\Theta$)')
    y_label = Str('Dataset')
    z_lower = Float(0.0)
    z_upper = Float
    z_label = Str
    # Edited z-axis label per scaling type.  __init__ gives every instance
    # its own dict so plots do not share (or reset) each other's labels.
    z_labels = {}

    group = VGroup(
        HGroup(
            VGroup(
                Item('azimuth',
                     editor=DefaultOverride(mode='slider',
                                            auto_set=False,
                                            enter_set=True)),
                Item('elevation',
                     editor=DefaultOverride(mode='slider',
                                            auto_set=False,
                                            enter_set=True)),
                Item('quality'),
                Item('flip_order'),
            ),
            VGroup(
                HGroup(
                    Item('x_label',
                         editor=DefaultOverride(auto_set=False,
                                                enter_set=True)),
                    Item('x_lower',
                         editor=DefaultOverride(auto_set=False,
                                                enter_set=True)),
                    Item('x_upper',
                         editor=DefaultOverride(auto_set=False,
                                                enter_set=True)),
                ),
                HGroup(Item('y_label'), ),
                HGroup(
                    Item('z_label',
                         editor=DefaultOverride(auto_set=False,
                                                enter_set=True)),
                    Item('z_lower',
                         editor=DefaultOverride(auto_set=False,
                                                enter_set=True)),
                    Item('z_upper',
                         editor=DefaultOverride(auto_set=False,
                                                enter_set=True)),
                ),
            ),
        ),
        UItem('figure', editor=MPLFigureEditor()),
    )

    def __init__(self, callback_obj=None, *args, **kws):
        """Create the pyplot figure and register an optional redraw listener.

        :param callback_obj: object whose ``_on_redraw(drawing=...)`` is
            called before (True) and after (False) each redraw; held only
            through a weak proxy.  ``None`` means no notifications.
        """
        super(MplPlot, self).__init__(*args, **kws)
        self.figure = plt.figure()
        self.figure.subplots_adjust(bottom=0.05, left=0, top=1, right=0.95)
        self.ax = None
        # Per-instance labels (the class-level dict would be shared).
        self.z_labels = dict(
            (s, 'Intensity - ' + get_value_scale_label(s, mpl=True))
            for s in self.scale_values)
        # This must be a weak reference, otherwise the entire app will
        # hang on exit.
        from weakref import proxy
        if callback_obj is not None:
            self._callback_object = proxy(callback_obj)
        else:
            self._callback_object = None

    def close(self):
        """Drop the callback reference and close this plot's figure."""
        self._callback_object = None
        # Close *this* figure; plt.close() alone closes the current figure,
        # which may belong to another plot.
        plt.close(self.figure)

    def __del__(self):
        plt.close(self.figure)

    def _notify_redraw(self, drawing):
        """Tell the callback object (if any) that a redraw starts/ends."""
        callback = getattr(self, '_callback_object', None)
        if callback is not None:
            callback._on_redraw(drawing=drawing)

    @on_trait_change('azimuth, elevation')
    def _perspective_changed(self):
        if self.ax:
            self.ax.view_init(azim=self.azimuth, elev=self.elevation)
            self.redraw()

    def _quality_changed(self):
        self.redraw(replot=True)

    @on_trait_change(
        'x_label, y_label, x_lower, x_upper, z_lower, z_upper, flip_order')
    def _trigger_redraw(self):
        """Rebuild the plot at the lowest quality after a setting changed."""
        self.quality = 1
        self.redraw(replot=True)

    def _z_label_changed(self):
        self.z_labels[self.scale] = self.z_label
        self._trigger_redraw()

    def redraw(self, replot=False, now=False):
        """Redraw the canvas, optionally rebuilding the plot.

        :param replot: rebuild the plot with ``_plot`` from the data of the
            last ``_plot`` call instead of only redrawing the canvas.
        :param now: draw immediately.  Drawing is currently always
            immediate; the deferred (timer-based) path is disabled.
        """
        if not now and self._draw_pending:
            self._redraw_timer.Restart()
            return
        canvas = self.figure.canvas
        if canvas is None:
            return
        if replot and getattr(self, 'x', None) is None:
            # Nothing has been plotted yet, so there is no data to replot.
            return

        self._notify_redraw(drawing=True)
        if replot:
            self._plot(self.x, self.y, self.z, self.scale)
        else:
            canvas.draw()
        self._draw_pending = False
        self._notify_redraw(drawing=False)

    def _prepare_data(self, stack):
        """Split a dataset stack into ``(x, y, z)`` grids.

        ``stack`` is presumably indexed ``stack[dataset, point, (x, z)]``
        (the x range below is read along axis 1).  ``y`` holds the 1-based
        dataset number for every point.  Also sets ``x_lower``/``x_upper``
        to the x extent and ``z_upper`` to the maximum intensity.
        """
        x = stack[:, :, 0]
        z = stack[:, :, 1]
        # One row per dataset (row of z), numbered from 1.
        y = array([[i] * z.shape[1] for i in range(1, z.shape[0] + 1)])

        if x[0, 0] < x[0, -1]:
            self.x_lower = x[0, 0]
            self.x_upper = x[0, -1]
        else:
            self.x_lower = x[0, -1]
            self.x_upper = x[0, 0]
        self.z_upper = z.max()
        return x, y, z

    def _plot(self, x, y, z, scale='linear'):
        """Draw the datasets as 3-D lines; stores x, y, z for later replots."""
        self.x, self.y, self.z = x, y, z
        x, y, z = x.copy(), y.copy(), z.copy()

        if self.flip_order:
            z = z[::-1]
        self.scale = scale
        self.figure.clear()
        self.figure.set_facecolor('white')
        ax = self.ax = self.figure.add_subplot(111, projection='3d')
        ax.set_xlabel(self.x_label)
        ax.set_ylabel(self.y_label)
        self.z_label = self.z_labels[self.scale]
        ax.set_zlabel(self.z_label)

        y_rows = z.shape[0]
        ax.locator_params(axis='y', nbins=10, integer=True)
        ax.view_init(azim=self.azimuth, elev=self.elevation)

        if self.quality != MAX_QUALITY:
            # map quality from 1->5 to 0.05->0.5 to approx. no. of samples
            samples = int(z.shape[1] * ((self.quality - 1) * (0.5 - 0.05) /
                                        (5 - 1) + 0.05))
            # Integer division: the helper expects a sample count.
            z, truncate_at, bins = rebin_preserving_peaks(z, samples // 2)
            # Take the x's from the original x's to maintain visual x-spacing
            # We need to calculate the x's for the rebinned data
            x0_row = x[0, :truncate_at]
            old_xs = np.linspace(x0_row.min(), x0_row.max(), bins * 2)
            new_xs = np.interp(
                old_xs, np.linspace(x0_row.min(), x0_row.max(), len(x0_row)),
                x0_row)
            x = np.tile(new_xs, (y.shape[0], 1))

        # Set values to inf to avoid rendering by matplotlib
        x[(x < self.x_lower) | (x > self.x_upper)] = np.inf
        z[(z < self.z_lower) | (z > self.z_upper)] = np.inf

        # separate series with open lines; materialize each segment (zip is
        # a lazy iterator on Python 3, which LineCollection cannot use).
        ys = y[:, 0]
        points = [list(zip(x_row, z_row)) for x_row, z_row in zip(x, z)]
        lines = LineCollection(points)
        ax.add_collection3d(lines, zs=ys, zdir='y')
        ax.set_xlim3d(self.x_lower, self.x_upper)
        ax.set_ylim3d(1, y_rows)
        ax.set_zlim3d(self.z_lower, self.z_upper)
        self.figure.canvas.draw()
        return None

    def copy_to_clipboard(self):
        self.figure.canvas.Copy_to_Clipboard()

    def save_as(self, filename):
        self.figure.canvas.print_figure(filename)
        logger.logger.info('Saved plot {}'.format(filename))

    def _reset_view(self):
        """Restore the default azimuth/elevation."""
        self.azimuth = -70
        self.elevation = 30
class DataPlotEditorBase(HasTraits):
    """Traits object owning a matplotlib Figure and a row/column of subplots.

    Subplots are created by add_subplots() (mpl_setup() creates ``nplots``
    of them); the figure is shown in the default view via MPLFigureEditor.
    """

    # Figure shown in the view.  Marked ``transient`` (traits metadata;
    # presumably to keep it out of saved/copied state -- TODO confirm).
    figure = Instance(Figure, (),transient=True)
    # Subplot axes created by add_subplots(); also transient.
    axs = List([],transient=True)
    # Number of subplots mpl_setup() creates.
    nplots = Int(1)
    # Subplot arrangement used by add_subplots(): 'vertical' stacks them in
    # one column, 'horizontal' places them in one row.  Initially 'vertical'.
    layout = Enum('vertical',['horizontal','vertical'])

    # Default view: the figure fills the resizable window.  MPLInitHandler
    # presumably calls mpl_setup() once the canvas exists -- TODO confirm.
    view = View(Item('figure', editor=MPLFigureEditor(),
                     show_label=False,
                     height=400,
                     width=400,
                     style='custom',
                     springy=True),
                handler=MPLInitHandler,
                resizable=True,
                kind='live',
                     )

    def __init__(self):
        """Initialise the traits and give the figure a transparent background.

        No subplots are created here; use mpl_setup() or add_subplots().
        """
        super(DataPlotEditorBase, self).__init__()
        background = self.figure.patch
        background.set_facecolor('none')


    def mpl_setup(self):
        """Make the figure background transparent and create ``nplots`` subplots."""
        background = self.figure.patch
        background.set_facecolor('none')
        self.add_subplots(self.nplots)

    def remove_subplots(self):
        """Clear the figure (keeping its observers) and forget all subplot axes."""
        fig = self.figure
        fig.clf(keep_observers=True)
        self.axs = []

    def remove_figure(self):
        """Close this object's figure through pyplot."""
        fig = self.figure
        plt.close(fig)

    def add_subplots(self, num):
        """Create ``num`` subplots arranged according to ``layout``.

        :param num: number of subplots to create.
        :returns: the list of new axes, which is also stored in ``self.axs``.
        """
        self.figure.patch.set_facecolor('none')
        self.axs = []
        if self.layout == 'vertical':
            nrows, ncols = num, 1
        elif self.layout == 'horizontal':
            nrows, ncols = 1, num
        else:
            return self.axs
        for position in range(1, num + 1):
            subplot = self.figure.add_subplot(nrows, ncols, position,
                                              facecolor='#F4EAEA', zorder=2)
            self.axs.append(subplot)
        return self.axs

    def add_common_labels(self, xlabel=None, ylabel=None):
        """Add x/y axis labels shared by all subplots.

        A frameless, transparent axes spanning the figure is added behind
        the subplots (zorder 1 vs. 2), with its ticks and tick labels hidden
        so that only its axis labels show.  The first subplot, if any, is
        then made the current axes again.

        :param xlabel: common x-axis label, or None for no label.
        :param ylabel: common y-axis label, or None for no label.
        """
        ax = self.figure.add_subplot(111, facecolor='none', frameon=False,
                                     zorder=1)
        # Booleans rather than 'on'/'off' strings: newer matplotlib no
        # longer accepts the strings (a non-empty string is truthy).
        ax.tick_params(labelcolor='none', top=False, bottom=False,
                       left=False, right=False)
        # Positional argument: the keyword was renamed from ``b`` to
        # ``visible`` in newer matplotlib.
        ax.grid(False)
        if xlabel is not None:
            ax.set_xlabel(xlabel)
        if ylabel is not None:
            ax.set_ylabel(ylabel)
        if self.axs:
            # Without subplots there is nothing to make current again.
            self.figure.sca(self.axs[0])

    def clear_plots(self):
        """Clear the contents of every subplot and redraw the canvas.

        The axes themselves stay in ``self.axs`` so they can be reused.
        """
        for ax in self.axs:
            # Axes.cla() takes no arguments; keep_observers only exists on
            # Figure.clf(), so passing it here raised TypeError.
            ax.cla()
        canvas = self.figure.canvas
        if canvas is not None:
            canvas.draw()

    def set_title(self, title=' ', size=13,y=0.98):
        """Set the figure-level title via ``Figure.suptitle``.

        :param title: title text (defaults to a single space).
        :param size: font size passed as ``fontsize``.
        :param y: vertical position passed to ``suptitle`` as ``y``.
        """
        self.figure.suptitle(title,fontsize=size,y=y)