Beispiel #1
0
 def do_proj(self):
     """Open the dialog that configures a new projection plot.

     A fresh ``ProjPlotSpec`` for ``self.pf`` is edited through a view
     offering the axis, field and weight-field choices.  The dialog is
     driven by a ``PlotCreationHandler`` built with
     ``plot_type=ProjPlotTab``.
     """
     # Dialog contents: axis, field and weight-field selectors.
     dialog_items = (
         Item('axis'),
         Item('field', editor=EnumEditor(name='field_list')),
         Item('weight_field', editor=EnumEditor(name='none_field_list')),
     )
     dialog_view = View(*dialog_items,
                        buttons=OKCancelButtons,
                        title="Projector: %s" % self.pf)

     spec = ProjPlotSpec(pf=self.pf)
     creation_handler = PlotCreationHandler(
         main_window=mw, pnode=self, model=spec,
         plot_type=ProjPlotTab, format="Proj: %s")
     spec.edit_traits(dialog_view, handler=creation_handler)
Beispiel #2
0
    def _get_table_editor ( self, names ):
        """ Returns the table editor used for editing the filter's rules.

            The incoming 'names' value is replaced by the editable trait
            names of 'self._object' before it is used.  When 'self.rules' is
            empty, one rule per name is created and then disabled.
        """
        from enthought.traits.ui.api import TableEditor

        names       = self._object.editable_traits()
        name_editor = EnumEditor( values = names )

        if len( self.rules ) == 0:
            # Seed the filter with one rule for every editable trait:
            seeded = []
            for trait_name in names:
                rule = GenericTableFilterRule( filter      = self,
                                               name_editor = name_editor )
                seeded.append( rule.set( name = trait_name ) )
            self.rules = seeded

            # The seeded rules start out switched off:
            for rule in self.rules:
                rule.enabled = False

        # Rows added through the editor share this filter and name editor:
        factory_kw = { 'filter': self, 'name_editor': name_editor }

        return TableEditor( columns        = menu_table_filter_rule_columns,
                            orientation    = 'vertical',
                            deletable      = True,
                            sortable       = False,
                            configurable   = False,
                            auto_size      = False,
                            auto_add       = True,
                            row_factory    = GenericTableFilterRule,
                            row_factory_kw = factory_kw )
Beispiel #3
0
    def default_traits_view(self):
        """Build the default view for the currently selected model.

        Without a model the view only shows a 'No Model Selected' label.
        Otherwise, for every name in ``self.model.params`` a Float trait and a
        Bool ``fixfit_<name>`` trait are added to this object, both wired to
        ``self._param_change_handler``, and laid out below a fit-technique
        selector.
        """
        view_buttons = ['Apply','Revert','OK','Cancel']

        if self.model is None:
            placeholder = Group()
            placeholder.content.append(Label('No Model Selected'))
            return View(placeholder,buttons=view_buttons)

        outer = Group(label=self.modelname,show_border=True,orientation='vertical')
        technique_row = HGroup(Item('fittype',label='Fit Technique',
                                    editor=EnumEditor(name='fittypes')))
        outer.content.append(technique_row)

        param_row = HGroup(scrollable=True)
        for par in self.model.params:
            cell = Group(orientation='horizontal',label=par)

            # value trait, initialised from the model before the handler is
            # attached
            self.add_trait(par,Float)
            setattr(self,par,getattr(self.model,par))
            self.on_trait_change(self._param_change_handler,par)
            cell.content.append(Item(par,show_label=False))

            # 'Fix?' flag: defaults to fixed if the parameter is a
            # class-level fixed parameter of the model
            fix_name = 'fixfit_'+par
            self.add_trait(fix_name,Bool)
            setattr(self,fix_name,par in self.model.__class__.fixedpars)
            self.on_trait_change(self._param_change_handler,fix_name)
            cell.content.append(Item(fix_name,label='Fix?'))

            param_row.content.append(cell)
        outer.content.append(param_row)

        return View(outer,buttons=view_buttons)
Beispiel #4
0
class NewModelSelector(HasTraits):
    """Dialog model for choosing a model by name.

    ``modelnames`` is filled from ``list_models`` at construction time, with
    the extra entry 'No Model'.  ``selectedmodelclass`` resolves the chosen
    name through ``get_model_class`` and ``isvarargmodel`` reports the
    class's ``isVarnumModel()`` result.
    """
    modelnames = List
    selectedname = Str('No Model')
    modelargnum = Int(2)
    selectedmodelclass = Property
    # NOTE(review): the getter reads 'selectedname' (via selectedmodelclass),
    # yet the declared dependency is 'modelnames' -- confirm this is intended.
    isvarargmodel = Property(depends_on='modelnames')

    traits_view = View(
        Item('selectedname',
             label='Model Name:',
             editor=EnumEditor(name='modelnames')),
        Item('modelargnum',
             label='Extra Parameters:',
             enabled_when='isvarargmodel'),
        buttons=['OK','Cancel'])

    def __init__(self,include_models=None,exclude_models=None,**traits):
        """Populate ``modelnames`` from ``list_models`` plus 'No Model'."""
        super(NewModelSelector,self).__init__(**traits)

        self.modelnames = list_models(include_models,exclude_models,FunctionModel1D)
        # 'No Model' is inserted first and the whole list is then sorted.
        self.modelnames.insert(0,'No Model')
        self.modelnames.sort()

    def _get_selectedmodelclass(self):
        """Return the class for ``selectedname``, or None for 'No Model'."""
        chosen = self.selectedname
        if chosen == 'No Model':
            return None
        return get_model_class(chosen)

    def _get_isvarargmodel(self):
        """Return False without a selected class, else its isVarnumModel()."""
        model_cls = self.selectedmodelclass
        return False if model_cls is None else model_cls.isVarnumModel()
Beispiel #5
0
class DataFitterPanel(DataFitter):
    """Data fitter with a GUI panel for choosing the fit function.

    The fit function is picked in two steps: a category from ``CATEGORIES``
    and then one of the function names that category lists in ``FUNCTIONS``.

    >>> import numpy
    >>> x = numpy.linspace(0,100)
    >>> y = numpy.linspace(0,100) + numpy.random.randn(50)
    >>> data = FitData(x = x, y = y)
    >>> t = DataFitter(data = data)
    
    >>> p,c =  t.fit()
    """
    category = Enum(list(CATEGORIES.keys()))
    function_names = Property(List(Str), depends_on='category')
    function_name = Str
    description = DelegatesTo('function')

    view = View(
        Group(
            'category',
            Item(name='function_name',
                 editor=EnumEditor(name='function_names'),
                 id='function_name_edit'),
            Item('description', style='custom'),
            label='Fit Function'),
        data_fitter_group,
        resizable=False)

    def _function_name_changed(self, name):
        """Install the function called `name` from the current category."""
        category_entry = CATEGORIES[self.category]
        self.function.function = getattr(category_entry, name)

    def _get_function_names(self):
        """Return the function names of the current category.

        NOTE(review): as a side effect this getter also resets
        ``function_name`` to the first returned name -- confirm callers rely
        on that.
        """
        available = CATEGORIES[self.category].FUNCTIONS
        self.function_name = available[0]
        return available
Beispiel #6
0
class RemapDemo(HasTraits):
    """Control panel that remaps an image through a surface expression.

    Any change to the surface expression, range, view height or grid flag
    triggers ``redraw``, which rebuilds the two remap tables with
    ``make_surf_map`` and shows the remapped image in the "Remap Demo" window.
    """
    surf_func = Str()
    func_list = List([
        "np.sqrt(8- x**2 - y**2)",
        "np.sin(6*np.sqrt(x**2+y**2))",
        "np.sin(6*x)",
        "np.sin(6*y)",
        "np.sin(np.sqrt(x**2+y**2))/np.sqrt(x**2+y**2)",
    ])
    range = Range(1.0, 100.0)
    view_height = Range(1.0, 50.0, 10.0)
    grid = Bool(True)

    view = View(
        Item("surf_func",
             label="曲面函数",
             editor=EnumEditor(name="func_list",
                               auto_set=False,
                               evaluate=lambda x: x)),
        Item("range", label="曲面范围"),
        Item("view_height", label="视点高度"),
        Item("grid", label="显示网格"),
        title="Remap Demo控制面板")

    def __init__(self, *args, **kwargs):
        """Load "lena.jpg", allocate the remap tables and hook up redraw."""
        super(RemapDemo, self).__init__(*args, **kwargs)
        self.img = cv.imread("lena.jpg")
        self.size = self.img.size()
        self.w = self.size.width
        self.h = self.size.height
        self.dstimg = cv.Mat()
        # One single-channel float table per output coordinate.
        self.map1 = cv.Mat(self.size, cv.CV_32FC1)
        self.map2 = cv.Mat(self.size, cv.CV_32FC1)
        self.gridimg = self.make_grid_img()
        self.on_trait_change(self.redraw, "surf_func,range,view_height,grid")

    def redraw(self):
        """Recompute the remap tables and display the remapped image.

        Returns without drawing when the surface expression raises
        SyntaxError.
        """
        # NOTE(security): the expression text is passed straight to eval();
        # only use this with trusted input.
        def func(x, y):
            return eval(self.surf_func, globals(), locals())

        try:
            tables = make_surf_map(func, self.range,
                                   self.w, self.h,
                                   self.view_height)
            self.map1[:], self.map2[:] = tables
        except SyntaxError:
            return

        source = self.gridimg if self.grid else self.img
        cv.remap(source, self.dstimg, self.map1, self.map2, cv.INTER_LINEAR)
        cv.imshow("Remap Demo", self.dstimg)

    def make_grid_img(self):
        """Return a copy of the image with a black grid drawn every 30 px."""
        canvas = self.img.clone()
        spacing = 30
        # vertical lines
        for col in range(0, self.w, spacing):
            cv.line(canvas, cv.Point(col, 0), cv.Point(col, self.h),
                    cv.CV_RGB(0, 0, 0), 1)
        # horizontal lines
        for row in range(0, self.h, spacing):
            cv.line(canvas, cv.Point(0, row), cv.Point(self.w, row),
                    cv.CV_RGB(0, 0, 0), 1)
        return canvas
Beispiel #7
0
class TestEnumEditor(HasTraits):
    """Holds a single 'value' trait that is edited through an EnumEditor."""

    #---------------------------------------------------------------------------
    #  Trait definitions:
    #---------------------------------------------------------------------------

    # Defaults to 1; the editor offers 'values' and converts input with int():
    value = Trait(
        1, enum, range,
        editor=EnumEditor(values=values, evaluate=int))
Beispiel #8
0
    def create_editor ( self):
        """ Returns an EnumEditor whose values are supplied by this object.

            Falls back to 3 columns when 'self.cols' is falsy and to the
            'radio' mode when 'self.mode' is falsy.
        """
        column_count = self.cols or 3
        evaluator    = self.evaluate
        editor_mode  = self.mode or 'radio'

        return EnumEditor( values   = self,
                           cols     = column_count,
                           evaluate = evaluator,
                           mode     = editor_mode )
Beispiel #9
0
class Equalizers(HasTraits):
    """A list of ``Equalizer`` stages and their combined response ``h``.

    ``h`` is recomputed as the element-wise product of the individual
    ``eq.h`` arrays whenever any ``eqs.h`` changes.  ``export`` writes the
    stages' parameters out as a C array definition.
    """
    eqs = List(Equalizer, [Equalizer()])
    # Builtin ``complex`` replaces the ``np.complex`` alias, which newer NumPy
    # releases no longer provide; the resulting dtype is the same.
    h = Array(dtype=complex, transient=True)

    # Editor definition for the Equalizer list ``eqs``
    table_editor = TableEditor(columns=[
        ObjectColumn(name="freq", width=0.4, style="readonly"),
        ObjectColumn(name="Q", width=0.3, style="readonly"),
        ObjectColumn(name="gain", width=0.3, style="readonly"),
    ],
                               deletable=True,
                               sortable=True,
                               auto_size=False,
                               show_toolbar=True,
                               edit_on_first_click=False,
                               orientation='vertical',
                               edit_view=View(Group(
                                   Item("freq",
                                        editor=EnumEditor(values=EQ_FREQS)),
                                   Item("freq", editor=scrubber(1.0)),
                                   Item("Q", editor=scrubber(0.01)),
                                   Item("gain", editor=scrubber(0.1)),
                                   show_border=True,
                               ),
                                              resizable=True),
                               row_factory=Equalizer)

    view = View(Item("eqs", show_label=False, editor=table_editor),
                width=0.25,
                height=0.5,
                resizable=True)

    @on_trait_change("eqs.h")
    def recalculate_h(self):
        """Compute the frequency response of the cascaded equalizers.

        Only stages whose ``h`` is set and has the same length as ``W``
        contribute.  Failures are ignored (best effort), leaving ``h``
        unchanged.
        """
        try:
            # ``is not None`` (rather than ``!= None``) avoids an element-wise
            # comparison when ``eq.h`` is an array.
            responses = np.array([
                eq.h for eq in self.eqs
                if eq.h is not None and len(eq.h) == len(W)
            ])
            self.h = np.prod(responses, axis=0)
        except Exception:
            # Deliberate best effort: keep the previous response on failure.
            pass

    def export(self, path):
        """Write the equalizer coefficients to `path` as a C source file."""
        # The context manager closes the file even if a stage fails to export.
        with open(path, "w") as f:
            f.write("double EQ_PARS[][5] = {\n")
            f.write("//b0,b1,b2,a1,a2 // frequency, Q, gain\n")
            for eq in self.eqs:
                eq.export_parameters(f)
            f.write("};\n")
Beispiel #10
0
class EnumExample(HasTraits):
    """Example of an Enum trait shown through an EnumEditor with a mapping.

    The editor's 'values' dictionary maps each priority to a label of the
    form '<number>:<name>'.
    """
    # 'Medium' is the first argument and is repeated in the list of choices.
    priority = Enum('Medium', 'Highest', 'High', 'Medium', 'Low', 'Lowest')

    view = View(
        Item(
            name='priority',
            editor=EnumEditor(values={
                'Highest': '1:Highest',
                'High': '2:High',
                'Medium': '3:Medium',
                'Low': '4:Low',
                'Lowest': '5:Lowest',
            })))
Beispiel #11
0
 def do_slice(self):
     """Open the dialog that configures a new slice plot.

     A fresh ``SlicePlotSpec`` for ``self.pf`` is edited through a view
     offering the axis, center and field choices.  The dialog is driven by a
     ``PlotCreationHandler`` built with ``plot_type=SlicePlotTab``.
     """
     # Dialog contents: axis, center and field selectors.
     dialog_items = (
         Item('axis'),
         Item('center'),
         Item('field', editor=EnumEditor(name='field_list')),
     )
     dialog_view = View(*dialog_items,
                        buttons=OKCancelButtons,
                        title="Slicer: %s" % self.pf)

     spec = SlicePlotSpec(pf=self.pf)
     creation_handler = PlotCreationHandler(
         main_window=mw, pnode=self, model=spec,
         plot_type=SlicePlotTab, format="Slice: %s")
     spec.edit_traits(dialog_view, handler=creation_handler)
Beispiel #12
0
class TableTest(HasStrictTraits):
    """Test harness showing a list of ``Person`` objects in a table editor.

    The 'state' column uses an EnumEditor whose choices come from
    ``_valid_states``.
    """

    #---------------------------------------------------------------------------
    #  Trait definitions:
    #---------------------------------------------------------------------------

    # The rows shown in the table:
    people = List(Person)

    #---------------------------------------------------------------------------
    #  Traits view definitions:
    #---------------------------------------------------------------------------

    # Choices offered by the 'state' column editor:
    _valid_states = List(["AL", "AR", "AZ", "AK"])

    _state_editor = EnumEditor(
        name="_valid_states",
        evaluate=evaluate_value,
        object='table_editor_object')

    table_editor = TableEditor(
        columns=[
            ObjectColumn(name='name'),
            ObjectColumn(name='age'),
            ObjectColumn(name='phone'),
            ObjectColumn(name='state', editor=_state_editor),
        ],
        editable=True,
        deletable=True,
        sortable=True,
        sort_model=True,
        show_lines=True,
        orientation='vertical',
        show_column_labels=True,
        edit_view=View(
            ['name', 'age', 'phone', 'state', '|[]'],
            resizable=True),
        filter=None,
        filters=filters,
        row_factory=Person)

    traits_view = View(
        [Item('people', id='people', editor=table_editor), '|[]<>'],
        title='Table Editor Test',
        id='enthought.traits.ui.tests.table_editor_test',
        dock='horizontal',
        width=.4,
        height=.3,
        resizable=True,
        buttons=NoButtons,
        kind='live')
Beispiel #13
0
 def default_traits_view(self):
     """Return the window layout: a row of controls above the plot."""
     # Top row: colour-map selector, reverse flag and read-only position.
     controls = HGroup(
         Item("current_map", label=u"颜色映射", editor=EnumEditor(name="object.color_maps")),
         Item("reverse_map", label=u"反转颜色"),
         Item("position", label=u"位置", style="readonly"),
     )
     canvas = Item("plot", show_label=False, editor=ComponentEditor())
     return View(
         VGroup(controls, canvas),
         resizable = True,
         width = 550, height = 300,
         title = u"Mandelbrot观察器"
     )
Beispiel #14
0
    def createView(self):
        """Set up a view for the traits and store it in ``self.view``."""
        # One selector per dimension; selector d is visible when the number
        # of dimensions is at least d + 1.
        # NOTE(review): this list is built but not added to the view below --
        # only the 'item' entry is shown. Confirm whether that is intended.
        indexItems = [
            Item('index%d' % dim,
                 editor=EnumEditor(name='index%d_enum' % dim),
                 visible_when=self.runtimeElementCondition +
                 self.numDimensionsString + '>=%d' % (dim + 1))
            for dim in range(3)
        ]

        items = [Item('item')]

        row = Group(*items, orientation='horizontal', show_labels=False)
        self.view = View(row,
                         buttons=NoButtons,
                         handler=SelectorHandler)
Beispiel #15
0
    def _get_table_editor ( self, names ):
        """ Returns the table editor used for editing the filter's rules.

            New rows are created by GenericTableFilterRule with this filter
            and an EnumEditor over 'names' as keyword arguments.
        """
        from enthought.traits.ui.api import TableEditor

        name_editor = EnumEditor( values = names )
        factory_kw  = { 'filter': self, 'name_editor': name_editor }

        return TableEditor( columns        = generic_table_filter_rule_columns,
                            orientation    = 'vertical',
                            deletable      = True,
                            sortable       = False,
                            configurable   = False,
                            auto_size      = False,
                            auto_add       = True,
                            row_factory    = GenericTableFilterRule,
                            row_factory_kw = factory_kw )
Beispiel #16
0
class Address(HasTraits):
    """ Demo class for demonstrating dynamic redefinition of valid trait values.

        The choices offered for 'city' are looked up at run time through the
        'handler.cities' name given to its EnumEditor.
    """

    street_address = Str
    # The keys are materialised with list() so that indexing also works where
    # dict.keys() returns a view instead of a list.
    st = Enum(list(cities.keys())[0], list(cities.keys()))
    city = Str

    view = View(Item(name='street_address'),
                Item(name='st', label='State'),
                Item(name='city',
                     editor=EnumEditor(name='handler.cities'),
                     id='cityedit'),
                title='Address Information',
                buttons=['OK'],
                resizable=True,
                handler=AddressHandler)
Beispiel #17
0
class Constraint(HasTraits):
    """One constraint row: a named parameter tied to a constraint variable.

    The names offered for ``constraint_name`` are the keys of
    ``runcase.constraint_variables``.
    """
    name = String
    runcase = Any
    constraint_variables_available = Property(
        List(String), depends_on='runcase.constraint_variables')

    @cached_property
    def _get_constraint_variables_available(self):
        """Return the keys of the run case's constraint variables."""
        variables = self.runcase.constraint_variables
        return variables.keys()

    constraint_name = String
    value = Float
    pattern = String
    cmd = String

    # Table layout: read-only parameter name, constraint chooser, float value.
    editor = TableEditor(
        auto_size=False,
        columns=[
            ObjectColumn(name='name', editable=False, label='Parameter'),
            ObjectColumn(
                name='constraint_name',
                label='Constraint',
                editor=EnumEditor(name='constraint_variables_available')),
            ObjectColumn(
                name='value',
                label='Value',
                editor=TextEditor(evaluate=float, enter_set=True,
                                  auto_set=False)),
        ])
Beispiel #18
0
class TrimCaseConfig(RunOptionsConfig):
    """Run-options configuration built around a single ``TrimCase``.

    The view group shows the trim case's type and parameter table, plus a
    selector for ``varying_param`` (chosen from ``param_name_list``) and an
    expression field for ``varying_expr``.
    """
    # NOTE(review): the default TrimCase() is created once, when the class
    # body runs -- confirm whether sharing it between instances is intended.
    trimcase = Instance(TrimCase, TrimCase())
    #runcase = DelegatesTo('trimcase')
    varying_param = String('velocity')
    varying_expr = Expression('100')
    # Parameter names of the trim case, recomputed when
    # 'trimcase.parameters' changes.
    param_name_list = Property(List(String), depends_on='trimcase.parameters')

    @cached_property
    def _get_param_name_list(self):
        """Return the trim case's parameter names sorted case-insensitively."""
        return sorted(self.trimcase.parameters.keys(), key=lambda x: x.lower())

    @on_trait_change('trimcase.parameter_view.value,trimcase.type')
    def on_trimcase_changed(self, object, name, old, new):
        """React to a changed parameter value or trim case type.

        A 'value' change is forwarded to the run case and the parameters are
        then refreshed; a 'type' change only refreshes the parameters.
        """
        print 'trimcase_changed'
        if name == 'value':
            self.trimcase.runcase.update_trim_case(self.trimcase, object)
            self.trimcase.update_parameters_from_avl()
        elif name == 'type':
            self.trimcase.update_parameters_from_avl()
            # NOTE(review): the result of this lookup is discarded; it only
            # evaluates the property (and raises IndexError on an empty
            # list). Confirm what was intended here.
            self.param_name_list[0]

    # send changed value of trimcase to avl, parameter is Parameter instance
    # NOTE(review): this method uses self.avl and
    # self.get_parameters_info_from_avl, neither of which is defined in this
    # class body -- presumably inherited; verify. on_trimcase_changed calls
    # the run case's update_trim_case, not this one.
    def update_trim_case(self, trimcase, parameter):
        print 'update_trim_case'
        self.get_parameters_info_from_avl(self.avl)
        self.avl.sendline('oper')
        self.avl.sendline(trimcase.type_)
        self.avl.expect(AVL.patterns['/oper/m'])
        #print self.parameters.keys()
        #for p, v in trimcase.parameters.iteritems():
        # Sends "<cmd> <value>" for the changed parameter.
        self.avl.sendline('%s %f' % (parameter.cmd, parameter.value))
        AVL.goto_state(self.avl)

    # View group: trim case type, parameter table, varying parameter choice.
    custom_group = Group(
        Item('object.trimcase.type'),
        Group(Item('object.trimcase.parameter_view',
                   editor=Parameter.editor,
                   show_label=False,
                   height=0.4),
              label='Parameters',
              scrollable=True),
        #Item('trimcase', style='custom', show_label=False),
        Group(Item('varying_param', editor=EnumEditor(name='param_name_list')),
              Item('varying_expr', label='Expression')))
Beispiel #19
0
class AnimatedGIFDemo(HasTraits):
    """Demo that shows an animated GIF chosen from the 'files' list.

    The same 'gif_file' trait appears twice in the view: once through the
    AnimatedGIFEditor and once as a chooser over 'files'.
    """

    # The animated GIF file to display:
    gif_file = File(files[0])

    # Is the animation playing or not?
    playing = Bool(True)

    # The traits view:
    view = View(
        VGroup(
            HGroup(
                Item('gif_file',
                     editor=AnimatedGIFEditor(playing='playing'),
                     show_label=False),
                Item('playing'),
            ),
            '_',
            Item('gif_file',
                 label='GIF File',
                 editor=EnumEditor(values=files))),
        title='Animated GIF Demo',
        buttons=['OK'])
Beispiel #20
0
class Sp4ArrayFileSource(ArraySource):
    """Array source whose scalar data is read from an ``sp4.Sp4File``.

    ``component`` selects which part of the file's array is exposed as
    ``scalar_data``: amplitude, phase, real part or imaginary part.
    """

    file_name = Str

    component = Enum('Amp', 'Phase', 'Real', 'Imag')
    set_component = Str

    # The opened file; stays None until a file name was given at construction.
    f = Instance(sp4.Sp4File)

    view = View(Group(Item(name='transpose_input_array'),
                      Item(name='file_name',
                           editor=TextEditor(auto_set=False,
                                             enter_set=True)),
                      # NOTE(review): 'component' here is the Enum trait
                      # defined above, not a list of values -- confirm the
                      # editor accepts it.
                      Item(name='component',
                           editor=EnumEditor(values=component)),
                      Item(name='scalar_name'),
                      Item(name='vector_name'),
                      Item(name='spacing'),
                      Item(name='origin'),
                      Item(name='update_image_data', show_label=False),
                      show_labels=True)
                )

    def __init__(self, **kw_args):
        """Create the source, opening `file_name` when one is supplied."""
        super(Sp4ArrayFileSource, self).__init__(**kw_args)
        fn = kw_args.get('file_name', None)
        if fn is not None:
            self.file_name = fn
            self._open(self.file_name)
        self.component = "Amp"
        # Called explicitly so the data is loaded even though the assignment
        # above may not change the trait's value.
        self._component_changed('Amp')

    def _open(self, fn):
        """Open the Sp4 file `fn` and keep it in ``self.f``."""
        self.f = sp4.Sp4File(fn)

    def _component_changed(self, info):
        """Refresh ``scalar_data`` from the file for the chosen component.

        Does nothing while no file is open.
        """
        if self.f is None:
            # Nothing to extract from yet (constructed without a file name).
            return
        extractors = {
            "Amp": numpy.abs,
            "Phase": numpy.angle,
            "Real": numpy.real,
            "Imag": numpy.imag,
        }
        extract = extractors.get(info)
        if extract is not None:
            self.scalar_data = extract(self.f.GetArray())
        self.update()

    def _file_name_changed(self, info):
        """Print the new file name."""
        print(self.file_name)
Beispiel #21
0
class ConstraintConfig(HasTraits):
    """Pairs a ``Constraint`` with an editable expression.

    On construction ``expr`` is initialised to the string form of
    ``constraint.value``.
    """
    constraint = Instance(Constraint)
    expr = Expression()
    constraint_variables_available = DelegatesTo('constraint')

    # Table layout: read-only parameter name, constraint chooser, expression.
    editor = TableEditor(
        auto_size=False,
        columns=[
            ObjectColumn(name='constraint.name',
                         editable=False,
                         label='Parameter'),
            ObjectColumn(name='constraint.constraint_name',
                         label='Constraint',
                         editor=EnumEditor(
                             name='constraint_variables_available')),
            ObjectColumn(name='expr', label='Expression'),
        ])

    def __init__(self, *args, **kwargs):
        """Initialise the traits, then seed ``expr`` from the constraint."""
        super(ConstraintConfig, self).__init__(*args, **kwargs)
        initial_value = self.constraint.value
        self.expr = str(initial_value)
Beispiel #22
0
class HierarchyImporter(HasTraits):
    """Options dialog model for importing grid levels from ``pf``.

    The field choices come from ``pf.h.field_list`` followed by
    ``pf.h.derived_field_list``; the number of importable levels is capped
    at 13.
    """
    pf = Any
    min_grid_level = Int(0)
    max_level = Int(1)
    number_of_levels = Range(0, 13)
    max_import_levels = Property(depends_on='min_grid_level')
    field = Str("Density")
    field_list = List
    center_on_max = Bool(True)
    center = CArray(shape=(3, ), dtype='float64')
    cache = Bool(True)
    smoothed = Bool(True)
    show_grids = Bool(True)

    default_view = View(
        Item('min_grid_level',
             editor=RangeEditor(low=0, high_name='max_level')),
        Item('number_of_levels',
             editor=RangeEditor(low=1, high_name='max_import_levels')),
        Item('field', editor=EnumEditor(name='field_list')),
        Item('center_on_max'),
        Item('center', enabled_when='not object.center_on_max'),
        Item('smoothed'),
        Item('cache', label='Pre-load data'),
        Item('show_grids'),
        buttons=OKCancelButtons)

    def _field_list_default(self):
        """Return the sorted field names followed by the sorted derived ones.

        Both source lists are sorted in place.
        """
        hierarchy = self.pf.h
        native_fields = hierarchy.field_list
        derived_fields = hierarchy.derived_field_list
        native_fields.sort()
        derived_fields.sort()
        return native_fields + derived_fields

    def _center_default(self):
        """Return the default center, [0.5, 0.5, 0.5]."""
        return [0.5, 0.5, 0.5]

    @cached_property
    def _get_max_import_levels(self):
        """Return the number of levels from min_grid_level up, at most 13."""
        levels_available = self.pf.h.max_level - self.min_grid_level + 1
        return min(13, levels_available)
Beispiel #23
0
class Graph(HasTraits):
    """
    Plot component: data-selection controls on the left and the plotting
    widget on the right.
    """
    name = Str # plot name, shown in the tab title and the plot title
    data_source = Instance(DataSource) # data source holding the data
    figure = Instance(Figure) # Figure object behind the plotting widget
    selected_xaxis = Str # name of the data used for the X axis
    selected_items = List # names of the data plotted on the Y axis

    clear_button = Button(u"清除") # quickly clears all selected Y-axis data

    view = View(
        HSplit( # left and right regions separated by a draggable handle
            # the left side is a single group
            VGroup(
                Item("name"),   # plot name edit box
                Item("clear_button"), # clear button
                Heading(u"X轴数据"),  # static text
                # X-axis selector: an EnumEditor (combo box) whose candidates
                # come from the names attribute of data_source
                Item("selected_xaxis", editor=
                    EnumEditor(name="object.data_source.names", format_str=u"%s")),
                Heading(u"Y轴数据"), # static text
                # Y-axis selector: several entries may be chosen, so a
                # check-box list shown in two columns is used
                Item("selected_items", style="custom", 
                     editor=CheckListEditor(name="object.data_source.names", 
                            cols=2, format_str=u"%s")),
                show_border = True, # draw a border around the group
                scrollable = True,  # scroll when the group has too many controls
                show_labels = False # no labels for any control in the group
            ),
            # plotting widget on the right
            Item("figure", editor=MPLFigureEditor(), show_label=False, width=600)
        )        
    )

    def _name_changed(self):
        """
        Update the plot title when the plot name changes.
        """
        axe = self.figure.axes[0]
        axe.set_title(self.name)
        self.figure.canvas.draw()

    def _clear_button_fired(self):
        """
        Handle the clear button: drop all Y-axis selections and replot.
        """
        self.selected_items = []
        self.update()

    def _figure_default(self):
        """
        Default value of the figure trait: a new Figure with one axes.
        """
        figure = Figure()
        figure.add_axes([0.1, 0.1, 0.85, 0.80]) # plot area with margins all round
        return figure

    def _selected_items_changed(self):
        """
        Replot when the Y-axis selection changes.
        """
        self.update()

    def _selected_xaxis_changed(self):
        """
        Replot when the X-axis selection changes.
        """    
        self.update()

    def update(self):
        """
        Redraw all curves.

        The axes are cleared first. If the X-axis data cannot be looked up,
        the method returns without plotting or redrawing the canvas.
        """    
        axe = self.figure.axes[0]
        axe.clear()
        try:
            xdata = self.data_source.data[self.selected_xaxis]
        except Exception:
            # No usable X data (e.g. no data source or unknown name): give up
            # quietly, but let KeyboardInterrupt/SystemExit propagate.
            return 
        for field in self.selected_items:
            axe.plot(xdata, self.data_source.data[field], label=field)
        axe.set_xlabel(self.selected_xaxis)
        axe.set_title(self.name)
        axe.legend()
        self.figure.canvas.draw()
# -*- coding: utf-8 -*-
Beispiel #25
0
class TVTKClassChooser(HasTraits):
    """Dialog model for choosing a TVTK class by name or by doc search.

    Changing ``class_name`` refreshes the prefix completions taken from
    ``available`` and the displayed documentation.  A ``search`` string of
    at least three characters replaces ``available`` with the results
    returned by ``finder``.
    """

    # The selected object, is None if no valid class_name was made.
    object = Property

    # The TVTK class name to choose.
    class_name = Str('', desc='class name of TVTK class (case sensitive)')

    # The string to search for in the class docs -- the search supports
    # 'and' and 'or' keywords.
    search = Str('',
                 desc='string to search in TVTK class documentation '
                      'supports the "and" and "or" keywords. '
                      'press <Enter> to start search. '
                      'This is case insensitive.')

    clear_search = Button

    # The class documentation.
    doc = Str(_search_help_doc)

    # Completions for the choice of class.
    completions = List(Str)

    # List of available class names as strings.
    available = List(TVTK_CLASSES)

    ########################################
    # Private traits.

    finder = Instance(DocSearch)

    # Maximum number of completions shown.
    n_completion = Int(25)

    ########################################
    # View related traits.

    view = View(
        Group(
            Item(name='class_name',
                 editor=EnumEditor(name='available')),
            Item(name='class_name',
                 has_focus=True),
            Item(name='search',
                 editor=TextEditor(enter_set=True,
                                   auto_set=False)),
            Item(name='clear_search',
                 show_label=False),
            Item('_'),
            Item(name='completions',
                 editor=ListEditor(columns=3),
                 style='readonly'),
            Item(name='doc',
                 resizable=True,
                 label='Documentation',
                 style='custom')),
        id='tvtk_doc',
        resizable=True,
        width=800,
        height=600,
        title='TVTK class chooser',
        buttons = ["OK", "Cancel"])

    ######################################################################
    # `object` interface.
    ######################################################################
    def __init__(self, **traits):
        super(TVTKClassChooser, self).__init__(**traits)
        # Copy of the initial class list, restored when a search is cleared.
        self._orig_available = list(self.available)

    ######################################################################
    # Non-public interface.
    ######################################################################
    def _get_object(self):
        """Return an instance of the named TVTK class, or None."""
        if len(self.class_name) == 0:
            return None
        try:
            return getattr(tvtk, self.class_name)()
        except (AttributeError, TypeError):
            return None

    def _class_name_changed(self, value):
        """Refresh the completions and documentation for the typed name."""
        matches = [name for name in self.available
                   if name.startswith(value)]
        self.completions = matches[:self.n_completion]
        # A single remaining candidate is filled in automatically.
        if len(matches) == 1 and value != matches[0]:
            self.class_name = matches[0]

        chosen = self.object
        if chosen is None:
            self.doc = _search_help_doc
        else:
            self.doc = get_tvtk_class_doc(chosen)

    def _finder_default(self):
        return DocSearch()

    def _clear_search_fired(self):
        self.search = ''

    def _search_changed(self, value):
        """Narrow ``available`` to the classes whose docs match `value`."""
        # Searches shorter than three characters restore the full list.
        if len(value) < 3:
            self.available = self._orig_available
            return

        hits = self.finder.search(value)
        hit_count = len(hits)
        if hit_count == 0:
            self.available = self._orig_available
        elif hit_count == 1:
            self.class_name = hits[0]
        else:
            self.available = hits
            self.completions = hits[:self.n_completion]
class IFSDesigner(HasTraits):
    """Interactive designer for iterated function systems (IFS).

    Triangles are edited with a `TrianglesTool` overlay on a scatter plot.
    A timer repeatedly calls `ifs_calculate`, which appends batches of
    points produced by `ifs` to the plot data.  Named point sets can be
    saved to, and are reloaded from, a pickle file in the working
    directory.
    """
    plot = Instance(Plot)
    clear = Bool(False)
    draw = Bool(False)
    timer = Instance(Timer)
    ifs_names = List()
    ifs_points = List()
    current_name = Str()
    save_button = Button(u"保存当前IFS")
    unsave_button = Button(u"删除当前IFS")

    view = View(
        HGroup(
            Item("current_name", editor = EnumEditor(name="object.ifs_names")),
            Item("save_button"),
            Item("unsave_button"),
            show_labels=False
        ),
        Item("plot", editor=ComponentEditor(),show_label=False),
        resizable=True,
        width = 500,
        height = 500,
        title = u"IFS图形设计器"
    )

    def __init__(self):
        """Build the plot and its tool, load saved systems, start the timer."""
        super(IFSDesigner, self).__init__()
        # Start point of the iteration; reset whenever `clear` is handled.
        self.initpos = [0, 0]
        self.data = ArrayPlotData()
        self.set_empty_data()
        self.plot = Plot(self.data, padding=10)
        self.plot.plot(("x", "y", "c"), type="cmap_scatter",
                       marker_size=1, color_mapper=make_color_map(),
                       line_width=0)
        self.plot.x_grid.visible = False
        self.plot.y_grid.visible = False
        self.plot.x_axis.visible = False
        self.plot.y_axis.visible = False
        self.tool = TrianglesTool(self.plot)
        self.plot.overlays.append(self.tool)

        self._load_data()

        self.tool.on_trait_change(self.triangle_changed, 'changed')
        self.timer = Timer(10, self.ifs_calculate)

    def _load_data(self):
        """Load saved (name, points) pairs from disk on a best-effort basis.

        The file written by `save_data` is tried first, then the lower-case
        name that earlier versions of this class read from (the two only
        differ on case-sensitive filesystems).  A missing or unreadable file
        is skipped silently, leaving the lists empty.
        """
        for path in ("IFS_chaco.data", "ifs_chaco.data"):
            try:
                with open(path, "rb") as f:
                    # NOTE(review): pickle must only be used on trusted files.
                    saved = pickle.load(f)
                names = [entry[0] for entry in saved]
                points = [np.array(entry[1]) for entry in saved]
            except Exception:
                # Missing, corrupt or malformed file: try the next candidate.
                continue
            self.ifs_names = names
            self.ifs_points = points
            if names:
                self.current_name = names[-1]
            return

    def set_empty_data(self):
        """Replace the plotted x, y and colour arrays with empty arrays."""
        self.data["x"] = np.array([])
        self.data["y"] = np.array([])
        self.data["c"] = np.array([])

    def triangle_changed(self):
        """Handle the tool's `changed` event.

        The plot data is emptied whenever the point count is a multiple of
        three.  Drawing stops below nine points, and a restart is requested
        (`clear`) at nine or more points that form whole triangles.
        """
        count = len(self.tool.points)
        complete = count % 3 == 0
        if complete:
            self.set_empty_data()

        if count < 9:
            self.draw = False
        elif complete:
            self.clear = True

    def ifs_calculate(self):
        """Timer callback: compute the next batch of points and plot it."""
        if self.clear:
            self.clear = False
            self.initpos = [0, 0]
            # The first 100 iterations are computed but not drawn.
            x, y, _ = ifs(self.tool.get_areas(), self.tool.get_eqs(),
                          self.initpos, 100)
            self.initpos = [x[-1], y[-1]]
            self.draw = True

        if not self.draw or len(self.data["x"]) >= ITER_COUNT * ITER_TIMES:
            return

        x, y, c = ifs(self.tool.get_areas(), self.tool.get_eqs(),
                      self.initpos, ITER_COUNT)
        # Drop the batch when a coordinate magnitude reaches 1e6 (or is NaN);
        # presumably this indicates a diverging system.
        if not (np.max(np.abs(x)) < 1000000 and np.max(np.abs(y)) < 1000000):
            return

        self.initpos = [x[-1], y[-1]]
        x = np.hstack((self.data["x"], x))
        y = np.hstack((self.data["y"], y))
        c = np.hstack((self.data["c"], c))
        self.data["x"], self.data["y"], self.data["c"] = x, y, c

        # Adjust the plot ranges so that the X:Y aspect ratio stays 1:1.
        w, h = float(self.plot.width), float(self.plot.height)
        if w <= 0 or h <= 0:
            # The plot has no size yet, so no scale can be derived.
            return
        xmin, xmax = np.min(x), np.max(x)
        ymin, ymax = np.min(y), np.max(y)
        xcenter, ycenter = (xmax + xmin) / 2.0, (ymax + ymin) / 2.0
        scale = max((xmax - xmin) / w, (ymax - ymin) / h)
        half_w, half_h = 0.5 * scale * w, 0.5 * scale * h
        self.plot.index_range.low = xcenter - half_w
        self.plot.index_range.high = xcenter + half_w
        self.plot.value_range.low = ycenter - half_h
        self.plot.value_range.high = ycenter + half_h

    def _current_name_changed(self):
        """Load the points of the selected system into the tool."""
        if self.current_name not in self.ifs_names:
            # Nothing to load, e.g. the name was cleared after a delete.
            return
        index = self.ifs_names.index(self.current_name)
        self.tool.points = list(self.ifs_points[index])
        self.tool.changed = True
        self.clear = True

    def _save_button_fired(self):
        """Handle the save button.

        Asks for a name, then stores the tool's current points under it,
        replacing the entry if the name already exists.
        """
        ask = AskName(name = self.current_name)
        if not ask.configure_traits():
            return
        points = self.tool.points[:]
        if ask.name in self.ifs_names:
            self.ifs_points[self.ifs_names.index(ask.name)] = points
        else:
            self.ifs_names.append(ask.name)
            self.ifs_points.append(points)
        self.save_data()
        self.current_name = ask.name

    def _unsave_button_fired(self):
        """Handle the delete button: remove the current system and save."""
        if self.current_name not in self.ifs_names:
            return
        index = self.ifs_names.index(self.current_name)
        del self.ifs_names[index]
        del self.ifs_points[index]
        if self.ifs_names:
            # Select the entry now at this position, or the new last entry
            # when the deleted one was at the end.
            index = min(index, len(self.ifs_names) - 1)
            self.current_name = self.ifs_names[index]
        else:
            self.current_name = ""
        self.save_data()

    def save_data(self):
        """Pickle all (name, points) pairs to ``IFS_chaco.data``."""
        with open("IFS_chaco.data", "wb") as f:
            # list() keeps this working where zip() returns an iterator.
            pickle.dump(list(zip(self.ifs_names, self.ifs_points)), f)
Beispiel #27
0
 def make_editor(trait=None):
     """Return an `EnumEditor` whose `name` is `trait.values_name`.

     NOTE(review): with the default ``trait=None`` the attribute access
     below raises AttributeError; presumably callers always pass a
     trait -- confirm.
     """
     values_trait_name = trait.values_name
     return EnumEditor(name=values_trait_name)
Beispiel #28
0
class ParticleArrayHelper(HasTraits):
    """
    Manages a single particle array together with the plotting related
    state (active scalar, mlab plot, legend and visibility) used to
    display it.
    """

    # The managed particle array.
    particle_array = Instance(ParticleArray)

    # Name of the particle array.
    name = Str

    # Name of the scalar currently viewed.
    scalar = Str('rho', desc='name of the active scalar to view')

    # The mlab plot showing this particle array.
    plot = Instance(PipelineBase)

    # Names of the scalars available in the particle array.
    scalar_list = List(Str)

    scene = Instance(MlabSceneModel)

    # Kept in sync with the scalar lut manager.
    show_legend = Bool(False, desc='if the scalar legend is to be displayed')

    # Kept in sync with the dataset to switch visibility on/off.
    visible = Bool(True, desc='if the particle array is to be displayed')

    ########################################
    # View related code.
    view = View(
        Item(name='name', show_label=False, editor=TitleEditor()),
        Group(
            Item(name='visible'),
            Item(name='scalar', editor=EnumEditor(name='scalar_list')),
            Item(name='show_legend'),
        ),
    )

    ######################################################################
    # Private interface.
    ######################################################################
    def _particle_array_changed(self, pa):
        """Refresh name, scalar list and plot for the new particle array.

        The plot is created on first use; afterwards its source is updated
        with `set` when the number of points is unchanged and with `reset`
        otherwise.
        """
        self.name = pa.name
        # Setup the scalars.
        scalar_names = list(pa.properties.keys())
        scalar_names.sort()
        self.scalar_list = scalar_names

        # Gather everything needed before touching the plot.
        x, y, z = pa.x, pa.y, pa.z
        u, v, w = pa.u, pa.v, pa.w
        scalars = getattr(pa, self.scalar)
        plot = self.plot
        mlab = self.scene.mlab

        if plot is not None:
            # Update the existing plot.
            source = plot.mlab_source
            if len(x) == len(source.x):
                source.set(x=x, y=y, z=z, scalars=scalars, u=u, v=v, w=w)
            else:
                source.reset(x=x, y=y, z=z, scalars=scalars, u=u, v=v, w=w)
            return

        # First array: build the pipeline.
        src = mlab.pipeline.vector_scatter(x, y, z, u, v, w, scalars=scalars)
        plot = mlab.pipeline.glyph(src, mode='point', scale_mode='none')
        plot.actor.property.point_size = 3
        plot.mlab_source.dataset.point_data.scalars.name = self.scalar
        lut_manager = plot.module_manager.scalar_lut_manager
        lut_manager.set(show_legend=self.show_legend,
                        use_default_name=False,
                        data_name=self.scalar)
        self.sync_trait('visible', plot.mlab_source.m_data, mutual=True)
        self.sync_trait('show_legend', lut_manager, mutual=True)
        self.plot = plot

    def _scalar_changed(self, value):
        """Show the scalar named `value` on the existing plot, if any."""
        plot = self.plot
        if plot is None:
            return
        plot.mlab_source.scalars = getattr(self.particle_array, value)
        plot.module_manager.scalar_lut_manager.data_name = value
Beispiel #29
0
class DataSourceWizardView(DataSourceWizard):
    """TraitsUI front-end for `DataSourceWizard`.

    Declares the label texts, the validity flags, the buttons and the
    views of the import wizard, plus the methods that open the wizard
    and preview the resulting data structure.
    """

    #----------------------------------------------------------------------
    # Private traits
    #----------------------------------------------------------------------

    # Static texts shown in the views.
    _top_label = Str('Describe your data')
    _info_text = Str('Array size do not match')
    _array_label = Str('Available arrays')
    _data_type_text = Str("What does your data represents?")
    _lines_text = Str("Connect the points with lines")
    _scalar_data_text = Str("Array giving the value of the scalars")
    _optional_scalar_data_text = Str("Associate scalars with the data points")
    _connectivity_text = Str("Array giving the triangles")
    _vector_data_text = Str("Associate vector components")

    # Text describing the position arrays; depends on the position type.
    _position_text = Property(depends_on="position_type_")

    _position_text_dict = {
        'explicit': 'Coordinnates of the data points:',
        'orthogonal grid': 'Position of the layers along each axis:',
    }

    def _get__position_text(self):
        """Return the text for the current position type ('' if unknown)."""
        return self._position_text_dict.get(self.position_type_, "")

    _shown_help_text = Str

    # One table row per available array.
    _data_sources_wrappers = Property(depends_on='data_sources')

    def _get__data_sources_wrappers(self):
        """Return an `ArrayColumnWrapper` (name, shape repr) per array."""
        wrappers = []
        for name in self._data_sources_names:
            shape = repr(self.data_sources[name].shape)
            wrappers.append(ArrayColumnWrapper(name=name, shape=shape))
        return wrappers

    # A trait pointing to the object itself, to play well with traitsUI.
    _self = Instance(DataSourceWizard)

    # Name of the view matching the current data type.
    _suitable_traits_view = Property(depends_on="data_type_")

    def _get__suitable_traits_view(self):
        """Return the view name, e.g. '_point_data_view' for 'point'."""
        return "_%s_data_view" % self.data_type_

    # The open UI; False until `view_wizard` has been called.
    ui = Any(False)

    # NOTE(review): the three `__*_fired` handlers below start with two
    # underscores and are therefore name-mangled by Python; their names
    # are kept exactly as they are.  Presumably Traits resolves the
    # mangled names -- confirm before renaming.
    _preview_button = Button(label='Preview structure')

    def __preview_button_fired(self):
        """Build the data source and show it in the preview window."""
        if not self.ui:
            return
        self.build_data_source()
        self.preview()

    _ok_button = Button(label='OK')

    def __ok_button_fired(self):
        """Close the dialog, then build the data source."""
        if not self.ui:
            return
        self.ui.dispose()
        self.build_data_source()

    _cancel_button = Button(label='Cancel')

    def __cancel_button_fired(self):
        """Close the dialog without building anything."""
        if not self.ui:
            return
        self.ui.dispose()

    # Validity flags used by the `enabled_when`, `visible_when` and
    # `invalid` settings of the views below.
    _is_ok = Bool

    _is_not_ok = Bool

    def _anytrait_changed(self):
        """ Validates if the OK button is enabled.
        """
        if not self.ui:
            return
        self._is_ok = self.check_arrays()
        self._is_not_ok = not self._is_ok

    _preview_window = Instance(PreviewWindow, ())

    _info_image = Instance(
        ImageResource,
        ImageLibrary.image_resource('@std:alert16',),
    )

    #----------------------------------------------------------------------
    # TraitsUI views
    #----------------------------------------------------------------------

    # NOTE(review): several group definitions below are one-element
    # tuples (in the original they ended with a trailing comma).  This is
    # kept as is; presumably TraitsUI accepts sequences as group
    # content -- confirm before "fixing" it.

    _coordinates_group = HGroup(
        Item('position_x', label='x',
             editor=EnumEditor(name='_data_sources_names',
                               invalid='_is_not_ok')),
        Item('position_y', label='y',
             editor=EnumEditor(name='_data_sources_names',
                               invalid='_is_not_ok')),
        Item('position_z', label='z',
             editor=EnumEditor(name='_data_sources_names',
                               invalid='_is_not_ok')),
    )

    _position_group = (
        Group(
            Item('position_type'),
            Group(
                Item('_position_text', style='readonly',
                     resizable=False,
                     show_label=False),
                _coordinates_group,
                visible_when='not position_type_=="image data"',
            ),
            Group(
                Item('grid_shape_source_',
                     label='Grid shape',
                     editor=EnumEditor(name='_grid_shape_source_labels',
                                       invalid='_is_not_ok')),
                HGroup(
                    spring,
                    Item('grid_shape', style='custom',
                         editor=ArrayEditor(width=-60),
                         show_label=False),
                    enabled_when='grid_shape_source==""',
                ),
                visible_when='position_type_=="image data"',
            ),
            label='Position of the data points',
            show_border=True,
            show_labels=False,
        ),
    )

    _connectivity_group = (
        Group(
            HGroup(
                Item('_connectivity_text', style='readonly',
                     resizable=False),
                spring,
                Item('connectivity_triangles',
                     editor=EnumEditor(name='_data_sources_names'),
                     show_label=False),
                show_labels=False,
            ),
            label='Connectivity information',
            show_border=True,
            show_labels=False,
            enabled_when='position_type_=="explicit"',
        ),
    )

    _scalar_data_group = Group(
        Item('_scalar_data_text', style='readonly',
             resizable=False,
             show_label=False),
        HGroup(
            spring,
            Item('scalar_data',
                 editor=EnumEditor(name='_data_sources_names',
                                   invalid='_is_not_ok')),
            show_labels=False,
        ),
        label='Scalar value',
        show_border=True,
        show_labels=False,
    )

    _optional_scalar_data_group = (
        Group(
            HGroup(
                'has_scalar_data',
                Item('_optional_scalar_data_text',
                     resizable=False,
                     style='readonly'),
                show_labels=False,
            ),
            Item('_scalar_data_text', style='readonly',
                 resizable=False,
                 enabled_when='has_scalar_data',
                 show_label=False),
            HGroup(
                spring,
                Item('scalar_data',
                     editor=EnumEditor(name='_data_sources_names',
                                       invalid='_is_not_ok'),
                     enabled_when='has_scalar_data'),
                show_labels=False,
            ),
            label='Scalar data',
            show_border=True,
            show_labels=False,
        ),
    )

    _vector_data_group = (
        VGroup(
            HGroup(
                Item('vector_u', label='u',
                     editor=EnumEditor(name='_data_sources_names',
                                       invalid='_is_not_ok')),
                Item('vector_v', label='v',
                     editor=EnumEditor(name='_data_sources_names',
                                       invalid='_is_not_ok')),
                Item('vector_w', label='w',
                     editor=EnumEditor(name='_data_sources_names',
                                       invalid='_is_not_ok')),
            ),
            label='Vector data',
            show_border=True,
        ),
    )

    _optional_vector_data_group = (
        VGroup(
            HGroup(
                Item('has_vector_data', show_label=False),
                Item('_vector_data_text', style='readonly',
                     resizable=False,
                     show_label=False),
            ),
            HGroup(
                Item('vector_u', label='u',
                     editor=EnumEditor(name='_data_sources_names',
                                       invalid='_is_not_ok')),
                Item('vector_v', label='v',
                     editor=EnumEditor(name='_data_sources_names',
                                       invalid='_is_not_ok')),
                Item('vector_w', label='w',
                     editor=EnumEditor(name='_data_sources_names',
                                       invalid='_is_not_ok')),
                enabled_when='has_vector_data',
            ),
            label='Vector data',
            show_border=True,
        ),
    )

    # Left-hand pane: table of the available arrays.
    _array_view = View(
        Item('_array_label', editor=TitleEditor(), show_label=False),
        Group(
            Item('_data_sources_wrappers',
                 editor=TabularEditor(adapter=ArrayColumnAdapter())),
            show_border=True,
            show_labels=False,
        ),
    )

    # Right-hand pane: data type question, type specific view and preview.
    # FIXME: Giving up on context sensitive help because of lack of time
    # (a tabbed 'Help' / 'Preview structure' layout was planned here).
    _questions_view = View(
        Item('_top_label', editor=TitleEditor(), show_label=False),
        HGroup(
            Item('_data_type_text', style='readonly', resizable=False),
            spring,
            'data_type',
            spring,
            show_border=True,
            show_labels=False,
        ),
        HGroup(
            Item('_self', style='custom',
                 editor=InstanceEditor(view_name='_suitable_traits_view')),
            Group(
                Item('_preview_button', enabled_when='_is_ok'),
                Item('_preview_window', style='custom',
                     label='Preview structure'),
                show_labels=False,
            ),
            show_labels=False,
            show_border=True,
        ),
    )

    # One view per data type; selected through `_suitable_traits_view`.
    _point_data_view = View(
        Group(
            Group(
                _coordinates_group,
                label='Position of the data points',
                show_border=True,
            ),
            HGroup(
                'lines',
                Item('_lines_text', style='readonly', resizable=False),
                label='Lines',
                show_labels=False,
                show_border=True,
            ),
            _optional_scalar_data_group,
            _optional_vector_data_group,
            # XXX: hack to have more vertical space
            Label('\n'),
            Label('\n'),
            Label('\n'),
        ),
    )

    _surface_data_view = View(
        Group(
            _position_group,
            _connectivity_group,
            _optional_scalar_data_group,
            _optional_vector_data_group,
        ),
    )

    _vector_data_view = View(
        Group(
            _vector_data_group,
            _position_group,
            _optional_scalar_data_group,
        ),
    )

    _volumetric_data_view = View(
        Group(
            _scalar_data_group,
            _position_group,
            _optional_vector_data_group,
        ),
    )

    # Top-level dialog: both panes plus the status line and the buttons.
    _wizard_view = View(
        Group(
            HGroup(
                Item('_self', style='custom', show_label=False,
                     editor=InstanceEditor(view='_array_view'),
                     width=0.17),
                '_',
                Item('_self', style='custom', show_label=False,
                     editor=InstanceEditor(view='_questions_view')),
            ),
            HGroup(
                Item('_info_image', editor=ImageEditor(),
                     visible_when="_is_not_ok"),
                Item('_info_text', style='readonly', resizable=False,
                     visible_when="_is_not_ok"),
                spring,
                '_cancel_button',
                Item('_ok_button', enabled_when='_is_ok'),
                show_labels=False,
            ),
        ),
        title='Import arrays',
        resizable=True,
    )

    #----------------------------------------------------------------------
    # Public interface
    #----------------------------------------------------------------------

    def __init__(self, **traits):
        """Initialise the wizard and point `_self` back at this object."""
        # NOTE(review): this calls DataSourceFactory.__init__ directly
        # although the visible base class is DataSourceWizard; kept
        # unchanged -- confirm the class hierarchy before switching to
        # super().
        DataSourceFactory.__init__(self, **traits)
        self._self = self

    def view_wizard(self):
        """ Pops up the view of the wizard, and keeps a reference to it
            in `ui` to be able to close it.
        """
        # FIXME: Workaround for traits bug in enabled_when: read each of
        # these traits once before the UI is built.
        for trait_name in ('position_type_', 'data_type_',
                           '_suitable_traits_view', 'grid_shape_source',
                           '_is_ok'):
            getattr(self, trait_name)
        self.ui = self.edit_traits(view='_wizard_view')

    def preview(self):
        """ Display a preview of the data structure in the preview
            window.
        """
        window = self._preview_window
        window.clear()
        window.add_source(self.data_source)

        # Glyph module using the first entry of the glyph list, without
        # data scaling.
        glyph = Glyph()
        glyph_source = glyph.glyph.glyph_source
        glyph_source.glyph_source = glyph_source.glyph_list[0]
        glyph.glyph.scale_mode = 'data_scaling_off'
        if not self.has_vector_data and self.data_type_ != 'vector':
            # No vectors to show: draw small crosses as points instead.
            glyph_source.glyph_source.glyph_type = 'cross'
            glyph.actor.property.representation = 'points'
            glyph.actor.property.point_size = 3.
        window.add_module(glyph)

        if self.data_type_ not in ('point', 'vector') or self.lines:
            surface = Surface()
            surface.actor.property.opacity = 0.3
            window.add_module(surface)

        if self.data_type_ != 'point':
            window.add_filter(ExtractEdges())
            edges = Surface()
            edges.actor.property.opacity = 0.2
            window.add_module(edges)
Beispiel #30
0
class FitGui(HasTraits):
    """
    This class represents the fitgui application state.
    """

    # -- plot components (all created in __init__) --
    plot = Instance(Plot)
    colorbar = Instance(ColorBar)
    plotcontainer = Instance(HPlotContainer)
    # -- model, fitting controls and data --
    tmodel = Instance(TraitedModel,allow_none=False)
    nomodel = Property #True when tmodel.model is None
    newmodel = Button('New Model...')
    fitmodel = Button('Fit Model')
    showerror = Button('Fit Error')
    updatemodelplot = Button('Update Model Plot')
    autoupdate = Bool(True)
    data = Array(dtype=float,shape=(2,None)) #set to [xdata,ydata] in __init__
    weights = Array
    weighttype = Enum(('custom','equal','lin bins','log bins'))
    weightsvary = Property(Bool) #whether any weight differs from the first one
    weights0rem = Bool(True)
    modelselector = NewModelSelector
    ytype = Enum(('data and model','residuals'))

    # -- interaction tools attached to the plot in __init__ --
    zoomtool = Instance(ZoomTool)
    pantool = Instance(PanTool)

    # -- selection state and actions --
    scattertool = Enum(None,'clicktoggle','clicksingle','clickimmediate','lassoadd','lassoremove','lassoinvert')
    selectedi = Property #indices of the selected objects
    weightchangesel = Button('Set Selection To')
    weightchangeto = Float(1.0)
    delsel = Button('Delete Selected')
    unselectonaction = Bool(True)
    clearsel = Button('Clear Selections')
    lastselaction = Str('None') #name of the last selection action that was fired

    datasymb = Button('Data Symbol...')
    modline = Button('Model Line...')

    # -- saving/restoring weights --
    savews = Button('Save Weights')
    loadws = Button('Load Weights')
    _savedws = Array

    # -- derived values; chi2/chi2r are recomputed when updatestats fires --
    plotname = Property
    updatestats = Event
    chi2 = Property(Float,depends_on='updatestats')
    chi2r = Property(Float,depends_on='updatestats')


    nmod = Int(1024) #number of points used to draw the model curve
    #modelpanel = View(Label('empty'),kind='subpanel',title='model editor')
    modelpanel = View

    # Compact view: the plot on top, and below it a row with the model name
    # (read-only), the number of model points, and the update controls. The
    # manual update button is only enabled while `autoupdate` is off.
    panel_view = View(VGroup(
                       Item('plot', editor=ComponentEditor(),show_label=False),
                       HGroup(Item('tmodel.modelname',show_label=False,style='readonly'),
                              Item('nmod',label='Number of model points'),
                              Item('updatemodelplot',show_label=False,enabled_when='not autoupdate'),
                              Item('autoupdate',label='Auto?'))
                      ),
                    title='Model Data Fitter'
                    )


    # Stand-alone view of the selection controls: the selection mode chooser
    # (the EnumEditor dict maps each `scattertool` value to its displayed
    # label) followed by the selection action buttons and the weight value
    # applied by 'Set Selection To'.
    selection_view = View(Group(
                           Item('scattertool',label='Selection Mode',
                                 editor=EnumEditor(values={None:'1:No Selection',
                                                           'clicktoggle':'3:Toggle Select',
                                                           'clicksingle':'2:Single Select',
                                                           'clickimmediate':'7:Immediate',
                                                           'lassoadd':'4:Add with Lasso',
                                                           'lassoremove':'5:Remove with Lasso',
                                                           'lassoinvert':'6:Invert with Lasso'})),
                           Item('unselectonaction',label='Clear Selection on Action?'),
                           Item('clearsel',show_label=False),
                           Item('weightchangesel',show_label=False),
                           Item('weightchangeto',label='Weight'),
                           Item('delsel',show_label=False)
                         ),title='Selection Options')

    # Main application view. From top to bottom: axis scaling/y-data choice,
    # the plot container, the weight/fit controls next to the model-plot
    # controls, a separator, the selection controls, and finally the editor
    # for the current `tmodel` embedded as a sub-panel.
    traits_view = View(VGroup(
                        HGroup(Item('object.plot.index_scale',label='x-scaling',
                                    enabled_when='object.plot.index_mapper.range.low>0 or object.plot.index_scale=="log"'),
                              spring,
                              Item('ytype',label='y-data'),
                              Item('object.plot.value_scale',label='y-scaling',
                                   enabled_when='object.plot.value_mapper.range.low>0 or object.plot.value_scale=="log"')
                              ),
                       Item('plotcontainer', editor=ComponentEditor(),show_label=False),
                       HGroup(VGroup(HGroup(Item('weighttype',label='Weights:'),
                                            Item('savews',show_label=False),
                                            Item('loadws',enabled_when='_savedws',show_label=False)),
                                Item('weights0rem',label='Remove 0-weight points for fit?'),
                                HGroup(Item('newmodel',show_label=False),
                                       Item('fitmodel',show_label=False),
                                       Item('showerror',show_label=False,enabled_when='tmodel.lastfitfailure'),
                                       VGroup(Item('chi2',label='Chi2:',style='readonly',format_str='%6.6g',visible_when='tmodel.model is not None'),
                                             Item('chi2r',label='reduced:',style='readonly',format_str='%6.6g',visible_when='tmodel.model is not None'))
                                       )#Item('selbutton',show_label=False))
                              ,springy=False),spring,
                              VGroup(HGroup(Item('autoupdate',label='Auto?'),
                              Item('updatemodelplot',show_label=False,enabled_when='not autoupdate')),
                              Item('nmod',label='Nmodel'),
                              HGroup(Item('datasymb',show_label=False),Item('modline',show_label=False)),springy=False),springy=True),
                       '_',
                       HGroup(Item('scattertool',label='Selection Mode',
                                 editor=EnumEditor(values={None:'1:No Selection',
                                                           'clicktoggle':'3:Toggle Select',
                                                           'clicksingle':'2:Single Select',
                                                           'clickimmediate':'7:Immediate',
                                                           'lassoadd':'4:Add with Lasso',
                                                           'lassoremove':'5:Remove with Lasso',
                                                           'lassoinvert':'6:Invert with Lasso'})),
                           Item('unselectonaction',label='Clear Selection on Action?'),
                           Item('clearsel',show_label=False),
                           Item('weightchangesel',show_label=False),
                           Item('weightchangeto',label='Weight'),
                           Item('delsel',show_label=False),
                         ),#layout='flow'),
                       Item('tmodel',show_label=False,style='custom',editor=InstanceEditor(kind='subpanel'))
                      ),
                    handler=FGHandler(),
                    resizable=True,
                    title='Data Fitting',
                    buttons=['OK','Cancel'],
                    width=700,
                    height=900
                    )


    def __init__(self,xdata=None,ydata=None,weights=None,model=None,
                 include_models=None,exclude_models=None,fittype=None,**traits):
        """
        Set up the application state (model, data, weights, plot and tools)
        for fitting the supplied data.

        :param xdata: the first dimension of the data to be fit
        :type xdata: array-like
        :param ydata: the second dimension of the data to be fit
        :type ydata: array-like
        :param weights:
            The weights to apply to the data. Statistically interpreted as inverse
            errors (*not* inverse variance). May be any of the following forms:

            * None for equal weights
            * an array of points that must match `ydata`
            * a 2-sequence of arrays (xierr,yierr) such that xierr matches the
              `xdata` and yierr matches `ydata`
            * a function called as f(params) that returns an array of weights
              that match one of the above two conditions

        :param model: the initial model to use to fit this data
        :type model:
            None, string, or :class:`pymodelfit.core.FunctionModel1D`
            instance.
        :param include_models:
            With `exclude_models`, specifies which models should be available in
            the "new model" dialog (see `models.list_models` for syntax).
        :param exclude_models:
            With `include_models`, specifies which models should be available in
            the "new model" dialog (see `models.list_models` for syntax).
        :param fittype:
            The fitting technique for the initial fit (see
            :class:`pymodelfit.core.FunctionModel`).
        :type fittype: string

        kwargs are passed in as any additional traits to apply to the
        application.

        :raises ValueError:
            If `xdata` or `ydata` is missing and the model carries no data
            to fall back on.

        """

        self.modelpanel = View(Label('empty'),kind='subpanel',title='model editor')

        self.tmodel = TraitedModel(model)

        if model is not None and fittype is not None:
            self.tmodel.model.fittype = fittype

        #fall back on the data stored in the model for anything not supplied
        if xdata is None or ydata is None:
            if not hasattr(self.tmodel.model,'data') or self.tmodel.model.data is None:
                raise ValueError('data not provided and no data in model')
            if xdata is None:
                xdata = self.tmodel.model.data[0]
            if ydata is None:
                ydata = self.tmodel.model.data[1]
            if weights is None:
                weights = self.tmodel.model.data[2]

        self.on_trait_change(self._paramsChanged,'tmodel.paramchange')

        self.modelselector = NewModelSelector(include_models,exclude_models)

        self.data = [xdata,ydata]


        if weights is None:
            self.weights = np.ones_like(xdata)
            self.weighttype = 'equal'
        else:
            self.weights = np.array(weights,copy=True)
            #NOTE(review): presumably assigning to the Button trait fires
            #_savews_fired, storing a copy of the initial weights - confirm
            self.savews = True

        #collapse multi-dimensional weights into one value per point by
        #summing their squares along the first axis
        weights1d = self.weights
        while len(weights1d.shape)>1:
            weights1d = np.sum(weights1d**2,axis=0)

        pd = ArrayPlotData(xdata=self.data[0],ydata=self.data[1],weights=weights1d)
        self.plot = plot = Plot(pd,resizable='hv')

        self.scatter = plot.plot(('xdata','ydata','weights'),name='data',
                         color_mapper=_cmapblack if self.weights0rem else _cmap,
                         type='cmap_scatter', marker='circle')[0]

        self.errorplots = None

        if not isinstance(model,FunctionModel1D):
            self.fitmodel = True

        self.updatemodelplot = False #force plot update - generates xmod and ymod
        plot.plot(('xmod','ymod'),name='model',type='line',line_style='dash',color='black',line_width=2)
        del plot.x_mapper.range.sources[-1]  #remove the line plot from the x_mapper source so only the data is tied to the scaling

        self.on_trait_change(self._rangeChanged,'plot.index_mapper.range.updated')

        #the pan tool is stored first in plot.tools
        self.pantool = PanTool(plot,drag_button='left')
        plot.tools.append(self.pantool)
        self.zoomtool = ZoomTool(plot)
        self.zoomtool.prev_state_key = KeySpec('a')
        self.zoomtool.next_state_key = KeySpec('s')
        plot.overlays.append(self.zoomtool)

        self.scattertool = None
        self.scatter.overlays.append(ScatterInspectorOverlay(self.scatter,
                        hover_color = "black",
                        selection_color="black",
                        selection_outline_color="red",
                        selection_line_width=2))


        #NOTE(review): color_mapper is given the mapper's *range* here, not
        #the mapper itself - confirm this is intended
        self.colorbar = colorbar = ColorBar(index_mapper=LinearMapper(range=plot.color_mapper.range),
                                            color_mapper=plot.color_mapper.range,
                                            plot=plot,
                                            orientation='v',
                                            resizable='v',
                                            width = 30,
                                            padding = 5)
        colorbar.padding_top = plot.padding_top
        colorbar.padding_bottom = plot.padding_bottom
        colorbar._axis.title = 'Weights'

        self.plotcontainer = container = HPlotContainer(use_backbuffer=True)
        container.add(plot)
        container.add(colorbar)

        super(FitGui,self).__init__(**traits)

        self.on_trait_change(self._scale_change,'plot.value_scale,plot.index_scale')

        if weights is not None and len(weights)==2:
            self.weightsChanged() #update error bars

    def _weights0rem_changed(self,old,new):
        """
        Swap the colormap of the plot when the `weights0rem` flag changes,
        keeping the current color range, and request a redraw.
        """
        make_mapper = _cmapblack if new else _cmap
        self.plot.color_mapper = make_mapper(self.plot.color_mapper.range)
        self.plot.request_redraw()

    def _paramsChanged(self):
        """
        Handler registered in `__init__` for 'tmodel.paramchange': requests
        a model plot update.
        """
        self.updatemodelplot = True

    def _nmod_changed(self):
        """
        Requests a model plot update when `nmod` changes.
        """
        self.updatemodelplot = True

    def _rangeChanged(self):
        """
        Handler registered in `__init__` for
        'plot.index_mapper.range.updated': requests a model plot update.
        """
        self.updatemodelplot = True

    #@on_trait_change('object.plot.value_scale,object.plot.index_scale',post_init=True)
    def _scale_change(self):
        """
        Handler registered in `__init__` for changes of the plot's value or
        index scale: requests a redraw of the plot.
        """
        self.plot.request_redraw()

    def _updatemodelplot_fired(self,new):
        #If the plot has not been generated yet, just skip the update
        if self.plot is None:
            return

        #if False (e.g. button click), update regardless, otherwise check for autoupdate
        if new and not self.autoupdate:
            return

        mod = self.tmodel.model
        if self.ytype == 'data and model':
            if mod:
                #xd = self.data[0]
                #xmod = np.linspace(np.min(xd),np.max(xd),self.nmod)
                xl = self.plot.index_range.low
                xh = self.plot.index_range.high
                if self.plot.index_scale=="log":
                    xmod = np.logspace(np.log10(xl),np.log10(xh),self.nmod)
                else:
                    xmod = np.linspace(xl,xh,self.nmod)
                ymod = self.tmodel.model(xmod)

                self.plot.data.set_data('xmod',xmod)
                self.plot.data.set_data('ymod',ymod)

            else:
                self.plot.data.set_data('xmod',[])
                self.plot.data.set_data('ymod',[])
        elif self.ytype == 'residuals':
            if mod:
                self.plot.data.set_data('xmod',[])
                self.plot.data.set_data('ymod',[])
                #residuals set the ydata instead of setting the model
                res = mod.residuals(*self.data)
                self.plot.data.set_data('ydata',res)
            else:
                self.ytype = 'data and model'
        else:
            assert True,'invalid Enum'


    def _fitmodel_fired(self):
        from warnings import warn

        preaup = self.autoupdate
        try:
            self.autoupdate = False
            xd,yd = self.data
            kwd = {'x':xd,'y':yd}
            if self.weights is not None:
                w = self.weights
                if self.weights0rem:
                    if xd.shape == w.shape:
                        m = w!=0
                        w = w[m]
                        kwd['x'] = kwd['x'][m]
                        kwd['y'] = kwd['y'][m]
                    elif np.any(w==0):
                        warn("can't remove 0-weighted points if weights don't match data")
                kwd['weights'] = w
            self.tmodel.fitdata = kwd
        finally:
            self.autoupdate = preaup

        self.updatemodelplot = True
        self.updatestats = True


#    def _tmodel_changed(self,old,new):
#        #old is only None before it is initialized
#        if new is not None and old is not None and new.model is not None:
#            self.fitmodel = True

    def _newmodel_fired(self,newval):
        """
        Replace the current model.

        If `newval` is a string, a `FunctionModel1D` instance or a
        `FunctionModel1D` subclass it is wrapped directly. Otherwise the
        model selector dialog is shown modally; a cancelled dialog changes
        nothing, and a selected model class is instantiated and fit.
        """
        from inspect import isclass

        given_directly = isinstance(newval,(basestring,FunctionModel1D)) or \
                         (isclass(newval) and issubclass(newval,FunctionModel1D))
        if given_directly:
            self.tmodel = TraitedModel(newval)
            return

        selector = self.modelselector
        if not selector.edit_traits(kind='modal').result:
            return #cancelled

        modelcls = selector.selectedmodelclass
        if modelcls is None:
            self.tmodel = TraitedModel(None)
            return

        if selector.isvarargmodel:
            newmodel = modelcls(selector.modelargnum)
        else:
            newmodel = modelcls()
        self.tmodel = TraitedModel(newmodel)
        self.fitmodel = True

    def _showerror_fired(self,evt):
        """
        Show the exception stored in `tmodel.lastfitfailure`, if any, in a
        small dialog.
        """
        failure = self.tmodel.lastfitfailure
        if not failure:
            return

        message = ': '.join((failure.__class__.__name__,str(failure)))
        holder = HasTraits(s=message)
        errorview = View(Item('s',style='custom',show_label=False),
                         resizable=True,buttons=['OK'],title='Fitting error message')
        holder.edit_traits(view=errorview)

    @cached_property
    def _get_chi2(self):
        """
        Getter for `chi2`: the first element returned by the model's
        `chi2Data()`, or 0 when it cannot be computed (e.g. no model).
        """
        try:
            return self.tmodel.model.chi2Data()[0]
        except Exception:
            #best-effort display value; unlike a bare except this does not
            #swallow KeyboardInterrupt/SystemExit
            return 0

    @cached_property
    def _get_chi2r(self):
        """
        Getter for `chi2r`: the second element returned by the model's
        `chi2Data()`, or 0 when it cannot be computed (e.g. no model).
        """
        try:
            return self.tmodel.model.chi2Data()[1]
        except Exception:
            #best-effort display value; unlike a bare except this does not
            #swallow KeyboardInterrupt/SystemExit
            return 0

    def _get_nomodel(self):
        """
        Getter for `nomodel`: True when the traited model wraps no model.
        """
        return self.tmodel.model is None

    def _get_weightsvary(self):
        w = self.weights
        return np.any(w!=w[0])if len(w)>0 else False

    def _get_plotname(self):
        xlabel = self.plot.x_axis.title
        ylabel = self.plot.y_axis.title
        if xlabel == '' and ylabel == '':
            return ''
        else:
            return xlabel+' vs '+ylabel
    def _set_plotname(self,val):
        if isinstance(val,basestring):
            val = val.split('vs')
            if len(val) ==1:
                val = val.split('-')
            val = [v.strip() for v in val]
        self.x_axis.title = val[0]
        self.y_axis.title = val[1]


    #selection-related
    def _scattertool_changed(self,old,new):
        if new == 'No Selection':
            self.plot.tools[0].drag_button='left'
        else:
            self.plot.tools[0].drag_button='right'
        if old is not None and 'lasso' in old:
            if new is not None and 'lasso' in new:
                #connect correct callbacks
                self.lassomode = new.replace('lasso','')
                return
            else:
                #TODO:test
                self.scatter.tools[-1].on_trait_change(self._lasso_handler,
                                            'selection_changed',remove=True)
                del self.scatter.overlays[-1]
                del self.lassomode
        elif old == 'clickimmediate':
            self.scatter.index.on_trait_change(self._immediate_handler,
                                            'metadata_changed',remove=True)

        self.scatter.tools = []
        if new is None:
            pass
        elif 'click' in new:
            smodemap = {'clickimmediate':'single','clicksingle':'single',
                        'clicktoggle':'toggle'}
            self.scatter.tools.append(ScatterInspector(self.scatter,
                                      selection_mode=smodemap[new]))
            if new == 'clickimmediate':
                self.clearsel = True
                self.scatter.index.on_trait_change(self._immediate_handler,
                                                    'metadata_changed')
        elif 'lasso' in new:
            lasso_selection = LassoSelection(component=self.scatter,
                                    selection_datasource=self.scatter.index)
            self.scatter.tools.append(lasso_selection)
            lasso_overlay = LassoOverlay(lasso_selection=lasso_selection,
                                         component=self.scatter)
            self.scatter.overlays.append(lasso_overlay)
            self.lassomode = new.replace('lasso','')
            lasso_selection.on_trait_change(self._lasso_handler,
                                            'selection_changed')
            lasso_selection.on_trait_change(self._lasso_handler,
                                            'selection_completed')
            lasso_selection.on_trait_change(self._lasso_handler,
                                            'updated')
        else:
            raise TraitsError('invalid scattertool value')

    def _weightchangesel_fired(self):
        """
        Set the weights of the selected points to `weightchangeto`.
        """
        #the weights are modified in place, before the selection is cleared
        self.weights[self.selectedi] = self.weightchangeto
        if self.unselectonaction:
            self.clearsel = True

        self._sel_alter_weights()
        #remembered so _immediate_handler can repeat this action
        self.lastselaction = 'weightchangesel'

    def _delsel_fired(self):
        """
        "Delete" the selected points by setting their weights to 0.
        """
        #the weights are modified in place, before the selection is cleared
        self.weights[self.selectedi] = 0
        if self.unselectonaction:
            self.clearsel = True

        self._sel_alter_weights()
        #remembered so _immediate_handler can repeat this action
        self.lastselaction = 'delsel'

    def _sel_alter_weights(self):
        """
        Finish an in-place edit of the weights: switch `weighttype` to
        'custom' if necessary and refresh everything depending on the
        weights.
        """
        if self.weighttype != 'custom':
            #store the edited weights first; _weighttype_changed assigns
            #_customweights back to `weights` when switching to 'custom'
            self._customweights = self.weights
            self.weighttype = 'custom'
        #in-place changes are not detected automatically, so update explicitly
        self.weightsChanged()

    def _clearsel_fired(self,event):
        if isinstance(event,list):
            self.scatter.index.metadata['selections'] = event
        else:
            self.scatter.index.metadata['selections'] = list()

    def _lasso_handler(self,name,new):
        if name == 'selection_changed':
            lassomask = self.scatter.index.metadata['selection'].astype(int)
            clickmask = np.zeros_like(lassomask)
            clickmask[self.scatter.index.metadata['selections']] = 1

            if self.lassomode == 'add':
                mask = clickmask | lassomask
            elif self.lassomode == 'remove':
                mask = clickmask & ~lassomask
            elif self.lassomode == 'invert':
                mask = np.logical_xor(clickmask,lassomask)
            else:
                raise TraitsError('lassomode is in invalid state')

            self.scatter.index.metadata['selections'] = list(np.where(mask)[0])
        elif name == 'selection_completed':
            self.scatter.overlays[-1].visible = False
        elif name == 'updated':
            self.scatter.overlays[-1].visible = True
        else:
            raise ValueError('traits event name %s invalid'%name)

    def _immediate_handler(self):
        """
        Handler attached to the scatter index's 'metadata_changed' event in
        'clickimmediate' mode: re-fires the last selection action on the
        newly selected point.
        """
        sel = self.selectedi
        if len(sel) > 1:
            #clear the selection before reporting the inconsistent state
            self.clearsel = True
            raise TraitsError('selection error in immediate mode - more than 1 selection')
        elif len(sel)==1:
            #lastselaction holds the name of a button trait; setting it to
            #True repeats that action
            if self.lastselaction != 'None':
                setattr(self,self.lastselaction,True)
            #remove the entry in place from the list obtained above
            del sel[0]

    def _savews_fired(self):
        """
        Store a copy of the current weights in `_savedws`.
        """
        self._savedws = self.weights.copy()

    def _loadws_fired(self):
        """
        Restore the weights stored in `_savedws`.
        """
        self.weights = self._savedws
        #re-save so `_savedws` becomes a fresh copy instead of the very
        #array that was just assigned to `weights`
        self._savews_fired()

    def _get_selectedi(self):
        """
        Getter for `selectedi`: the 'selections' entry of the scatter
        plot's index metadata (the object itself, not a copy).
        """
        return self.scatter.index.metadata['selections']


    @on_trait_change('data,ytype',post_init=True)
    def dataChanged(self):
        """
        Push the current fit data into the plot and force a model plot
        update.

        The GUI calls this automatically when a new data array is assigned
        (or `ytype` changes), but it must be called by hand after the data
        has been modified in-place.
        """
        plotdata = self.plot.data
        #TODO:make set_data apply to both simultaneously?
        for key,row in (('xdata',0),('ydata',1)):
            plotdata.set_data(key,self.data[row])

        #a false value forces the update regardless of autoupdate
        self.updatemodelplot = False

    @on_trait_change('weights',post_init=True)
    def weightsChanged(self):
        """
        Updates the application state if the weights/error bars for this model
        are changed - the GUI will automatically do this if you give it a new
        set of weights array, but not if they are changed in-place.

        Existing error bar plots are removed and, when the weights have
        shape (2,N), rebuilt from 1/weights. The weights shown through the
        colormap are collapsed to one value per point, and the color bar is
        only kept in the plot container while the weights vary.
        """
        weights = self.weights
        if 'errorplots' in self.trait_names():
            #TODO:switch this to updating error bar data/visibility changing
            if self.errorplots is not None:
                self.plot.remove(self.errorplots[0])
                self.plot.remove(self.errorplots[1])
                #forget the removed plots so they are not removed a second
                #time on the next call
                self.errorplots = None

            if len(weights.shape)==2 and weights.shape[0]==2:
                #weights are inverse errors
                xerr,yerr = 1/weights

                high = ArrayDataSource(self.scatter.index.get_data()+xerr)
                low = ArrayDataSource(self.scatter.index.get_data()-xerr)
                #x error bars: index and value are swapped relative to the
                #scatter plot
                ebpx = ErrorBarPlot(orientation='v',
                                   value_high = high,
                                   value_low = low,
                                   index = self.scatter.value,
                                   value = self.scatter.index,
                                   index_mapper = self.scatter.value_mapper,
                                   value_mapper = self.scatter.index_mapper
                                )
                self.plot.add(ebpx)

                high = ArrayDataSource(self.scatter.value.get_data()+yerr)
                low = ArrayDataSource(self.scatter.value.get_data()-yerr)
                ebpy = ErrorBarPlot(value_high = high,
                                   value_low = low,
                                   index = self.scatter.index,
                                   value = self.scatter.value,
                                   index_mapper = self.scatter.index_mapper,
                                   value_mapper = self.scatter.value_mapper
                                )
                self.plot.add(ebpy)

                self.errorplots = (ebpx,ebpy)

        #collapse multi-dimensional weights into one value per point by
        #summing their squares along the first axis
        while len(weights.shape)>1:
            weights = np.sum(weights**2,axis=0)
        self.plot.data.set_data('weights',weights)
        self.plot.plots['data'][0].color_mapper.range.refresh()

        if self.weightsvary:
            if self.colorbar not in self.plotcontainer.components:
                self.plotcontainer.add(self.colorbar)
                self.plotcontainer.request_redraw()
        elif self.colorbar in self.plotcontainer.components:
            self.plotcontainer.remove(self.colorbar)
            self.plotcontainer.request_redraw()


    def _weighttype_changed(self, name, old, new):
        if old == 'custom':
            self._customweights = self.weights

        if new == 'custom':
            self.weights = self._customweights #if hasattr(self,'_customweights') else np.ones_like(self.data[0])
        elif new == 'equal':
            self.weights = np.ones_like(self.data[0])
        elif new == 'lin bins':
            self.weights = binned_weights(self.data[0],10,False)
        elif new == 'log bins':
            self.weights = binned_weights(self.data[0],10,True)
        else:
            raise TraitError('Invalid Enum value on weighttype')

    def getModelInitStr(self):
        """
        Builds a python code string that, when evaluated, creates a model
        with the same parameters as the model in this :class:`FitGui`.

        :returns: initializer string ('None' if there is no model)

        """
        model = self.tmodel.model
        if model is None:
            return 'None'

        modelcls = model.__class__
        args = [name+'='+str(value) for name,value in model.pardict.iteritems()]
        if modelcls._pars is None: #varargs need to have the first argument give the right number
            nvar = len(model.params)-len(modelcls._statargs)
            args.insert(0,str(nvar))
        return '%s(%s)'%(modelcls.__name__,','.join(args))

    def getModelObject(self):
        """
        Gets the underlying object representing the model for this fit.

        :returns:
            The :class:`pymodelfit.core.FunctionModel1D` object stored in
            the traited model. Other methods of this class check it against
            None, so it may be None when no model is set.
        """
        return self.tmodel.model