def do_proj(self):
    """Open the projection-plot configuration dialog for ``self.pf``.

    A ``ProjPlotSpec`` is created for the parameter file and edited through
    an OK/Cancel view; a ``PlotCreationHandler`` receives the result.
    """
    dialog_items = (
        Item('axis'),
        Item('field', editor=EnumEditor(name='field_list')),
        Item('weight_field', editor=EnumEditor(name='none_field_list')),
    )
    dialog_view = View(*dialog_items,
                       buttons=OKCancelButtons,
                       title="Projector: %s" % self.pf)
    spec = ProjPlotSpec(pf=self.pf)
    # NOTE(review): ``mw`` is not defined in this function; presumably a
    # module-level main-window reference -- confirm.
    creation_handler = PlotCreationHandler(
        main_window=mw,
        pnode=self,
        model=spec,
        plot_type=ProjPlotTab,
        format="Proj: %s",
    )
    spec.edit_traits(dialog_view, handler=creation_handler)
def _get_table_editor ( self, names ):
    """Return the table editor used for editing this filter's rules.

    The ``names`` argument is not used: it is replaced straight away by the
    editable trait names of ``self._object``.  When the filter has no rules
    yet, one disabled rule is created per trait name.
    """
    from enthought.traits.ui.api import TableEditor

    names = self._object.editable_traits()
    name_editor = EnumEditor( values = names )

    if len( self.rules ) == 0:
        initial_rules = []
        for trait_name in names:
            new_rule = GenericTableFilterRule( filter      = self,
                                               name_editor = name_editor )
            initial_rules.append( new_rule.set( name = trait_name ) )
        self.rules = initial_rules

        # Rules are switched off only after they were assigned to the filter.
        for existing_rule in self.rules:
            existing_rule.enabled = False

    row_kw = { 'filter': self, 'name_editor': name_editor }
    return TableEditor(
        columns        = menu_table_filter_rule_columns,
        orientation    = 'vertical',
        deletable      = True,
        sortable       = False,
        configurable   = False,
        auto_size      = False,
        auto_add       = True,
        row_factory    = GenericTableFilterRule,
        row_factory_kw = row_kw )
def default_traits_view(self):
    """Build the view for the currently selected model.

    Without a model the view only shows a 'No Model Selected' label.
    Otherwise it shows the fit-technique selector followed by one group per
    model parameter.  For each parameter two traits are added to ``self`` on
    the fly: a ``Float`` holding the value and a ``Bool`` named
    ``fixfit_<param>``; both are wired to ``self._param_change_handler``.
    """
    dialog_buttons = ['Apply','Revert','OK','Cancel']

    if self.model is None:
        placeholder = Group()
        placeholder.content.append(Label('No Model Selected'))
        return View(placeholder,buttons=dialog_buttons)

    outer = Group(label=self.modelname,show_border=True,orientation='vertical')
    fit_row = HGroup(Item('fittype',label='Fit Technique',
                          editor=EnumEditor(name='fittypes')))
    outer.content.append(fit_row)

    params_row = HGroup(scrollable=True)
    for par in self.model.params:
        par_group = Group(orientation='horizontal',label=par)

        # Value trait, initialised from the model.
        self.add_trait(par,Float)
        setattr(self,par,getattr(self.model,par))
        self.on_trait_change(self._param_change_handler,par)
        par_group.content.append(Item(par,show_label=False))

        # "Fix?" flag; defaults to fixed when the model class lists the
        # parameter in its class-level ``fixedpars``.
        fix_name = 'fixfit_'+par
        self.add_trait(fix_name,Bool)
        setattr(self,fix_name,par in self.model.__class__.fixedpars)
        self.on_trait_change(self._param_change_handler,fix_name)
        par_group.content.append(Item(fix_name,label='Fix?'))

        params_row.content.append(par_group)
    outer.content.append(params_row)

    return View(outer,buttons=dialog_buttons)
class NewModelSelector(HasTraits):
    """Dialog model for choosing a model by name.

    ``modelnames`` is filled in ``__init__`` from ``list_models`` and always
    starts with the 'No Model' sentinel.  ``selectedmodelclass`` resolves the
    chosen name through ``get_model_class`` (``None`` for the sentinel), and
    ``isvarargmodel`` reports the class's ``isVarnumModel()`` so the
    'Extra Parameters' field is only enabled when it applies.
    """

    modelnames = List
    selectedname = Str('No Model')
    modelargnum = Int(2)
    selectedmodelclass = Property
    isvarargmodel = Property(depends_on='modelnames')

    traits_view = View(Item('selectedname',label='Model Name:',
                            editor=EnumEditor(name='modelnames')),
                       Item('modelargnum',label='Extra Parameters:',
                            enabled_when='isvarargmodel'),
                       buttons=['OK','Cancel'])

    def __init__(self,include_models=None,exclude_models=None,**traits):
        """Populate ``modelnames``.

        ``include_models`` and ``exclude_models`` are passed unchanged to
        ``list_models``; remaining keyword arguments initialise traits.
        """
        super(NewModelSelector,self).__init__(**traits)

        names = sorted(list_models(include_models,exclude_models,FunctionModel1D))
        # Sort first, then prepend: the original inserted the sentinel and
        # sorted afterwards, which moves 'No Model' away from the top as soon
        # as any model name sorts before it.
        self.modelnames = ['No Model'] + names

    def _get_selectedmodelclass(self):
        """Return the class for ``selectedname`` or None for 'No Model'."""
        chosen = self.selectedname
        if chosen == 'No Model':
            return None
        return get_model_class(chosen)

    def _get_isvarargmodel(self):
        """Return True when the selected class reports ``isVarnumModel()``."""
        cls = self.selectedmodelclass
        if cls is None:
            return False
        return cls.isVarnumModel()
class DataFitterPanel(DataFitter):
    """Data fitter with a GUI for picking the fit function.

    The fit function is chosen in two steps: a ``category`` (a key of
    ``CATEGORIES``) and then a ``function_name`` from that category's
    ``FUNCTIONS`` list.

    >>> import numpy
    >>> x = numpy.linspace(0,100)
    >>> y = numpy.linspace(0,100) + numpy.random.randn(50)
    >>> data = FitData(x = x, y = y)
    >>> t = DataFitter(data = data)
    >>> p,c = t.fit()
    """

    category = Enum(list(CATEGORIES.keys()))
    function_names = Property(List(Str), depends_on='category')
    function_name = Str
    description = DelegatesTo('function')

    view = View(
        Group(
            'category',
            Item(name='function_name',
                 editor=EnumEditor(name='function_names'),
                 id='function_name_edit'),
            Item('description', style='custom'),
            label='Fit Function',
        ),
        data_fitter_group,
        resizable=False,
    )

    def _function_name_changed(self, name):
        """Install the function called ``name`` from the current category."""
        provider = CATEGORIES[self.category]
        self.function.function = getattr(provider, name)

    def _get_function_names(self):
        """Return the current category's function names.

        Side effect: ``function_name`` is reset to the first entry every time
        this getter runs.
        """
        available = CATEGORIES[self.category].FUNCTIONS
        self.function_name = available[0]
        return available
class RemapDemo(HasTraits):
    """Control panel that remaps an image over a user-chosen surface.

    The surface is given as a Python expression in ``x`` and ``y``
    (``surf_func``).  Whenever the expression, the range, the view height or
    the grid flag changes, the remap tables are rebuilt with
    ``make_surf_map`` and the result is shown with ``cv.imshow``.
    """

    surf_func = Str()
    func_list = List([
        "np.sqrt(8- x**2 - y**2)",
        "np.sin(6*np.sqrt(x**2+y**2))",
        "np.sin(6*x)",
        "np.sin(6*y)",
        "np.sin(np.sqrt(x**2+y**2))/np.sqrt(x**2+y**2)",
    ])
    range = Range(1.0, 100.0)
    view_height = Range(1.0, 50.0, 10.0)
    grid = Bool(True)

    view = View(
        Item("surf_func", label="曲面函数",
             editor=EnumEditor(name="func_list", auto_set=False,
                               evaluate=lambda x: x)),
        Item("range", label="曲面范围"),
        Item("view_height", label="视点高度"),
        Item("grid", label="显示网格"),
        title="Remap Demo控制面板",
    )

    def __init__(self, *args, **kwargs):
        """Load ``lena.jpg``, allocate the remap tables and hook up redraw."""
        super(RemapDemo, self).__init__(*args, **kwargs)
        self.img = cv.imread("lena.jpg")
        self.size = self.img.size()
        self.w, self.h = self.size.width, self.size.height
        self.dstimg = cv.Mat()
        self.map1 = cv.Mat(self.size, cv.CV_32FC1)
        self.map2 = cv.Mat(self.size, cv.CV_32FC1)
        self.gridimg = self.make_grid_img()
        self.on_trait_change(self.redraw, "surf_func,range,view_height,grid")

    def redraw(self):
        """Rebuild the remap tables and display the remapped image.

        A ``SyntaxError`` raised while building the tables (e.g. from a
        half-typed expression) aborts the redraw silently.
        """
        def func(x, y):
            # NOTE(review): evaluates the text of ``surf_func`` with eval();
            # the field is editable in the UI, so only use with trusted input.
            return eval(self.surf_func, globals(), locals())

        try:
            self.map1[:], self.map2[:] = make_surf_map(
                func, self.range, self.w, self.h, self.view_height)
        except SyntaxError:
            return

        source = self.gridimg if self.grid else self.img
        cv.remap(source, self.dstimg, self.map1, self.map2, cv.INTER_LINEAR)
        cv.imshow("Remap Demo", self.dstimg)

    def make_grid_img(self):
        """Return a copy of the image with black lines drawn every 30 pixels."""
        canvas = self.img.clone()
        black = (0, 0, 0)
        for col in range(0, self.w, 30):
            cv.line(canvas, cv.Point(col, 0), cv.Point(col, self.h),
                    cv.CV_RGB(*black), 1)
        for row in range(0, self.h, 30):
            cv.line(canvas, cv.Point(0, row), cv.Point(self.w, row),
                    cv.CV_RGB(*black), 1)
        return canvas
class TestEnumEditor(HasTraits):
    """Test object exposing a single ``value`` trait edited by an EnumEditor."""

    # Trait definitions:

    # Defaults to 1; the trait accepts whatever the ``enum`` and ``range``
    # objects (defined outside this class) allow.  The editor offers
    # ``values`` and converts the chosen entry with ``int``.
    value = Trait(
        1,
        enum,
        range,
        editor=EnumEditor(values=values, evaluate=int),
    )
def create_editor ( self):
    """Return an EnumEditor whose values are supplied by ``self``.

    ``cols`` falls back to 3 and ``mode`` to 'radio' when the corresponding
    attribute is falsy.
    """
    return EnumEditor(
        values   = self,
        cols     = self.cols or 3,
        evaluate = self.evaluate,
        mode     = self.mode or 'radio',
    )
class Equalizers(HasTraits):
    """A cascade of ``Equalizer`` objects and their combined response ``h``.

    ``eqs`` is edited through a table; ``h`` is recomputed whenever any
    equalizer's own ``h`` changes, and ``export`` writes the coefficients as
    a C array.
    """

    eqs = List(Equalizer, [Equalizer()])
    # ``np.complex`` was only an alias of the builtin ``complex`` and has been
    # removed from NumPy, so the builtin is used directly.
    h = Array(dtype=complex, transient=True)

    # Editor definition for the ``eqs`` list of Equalizer objects.
    table_editor = TableEditor(
        columns=[
            ObjectColumn(name="freq", width=0.4, style="readonly"),
            ObjectColumn(name="Q", width=0.3, style="readonly"),
            ObjectColumn(name="gain", width=0.3, style="readonly"),
        ],
        deletable=True,
        sortable=True,
        auto_size=False,
        show_toolbar=True,
        edit_on_first_click=False,
        orientation='vertical',
        edit_view=View(
            Group(
                # "freq" appears twice: once as a preset list, once as a
                # scrubber.
                Item("freq", editor=EnumEditor(values=EQ_FREQS)),
                Item("freq", editor=scrubber(1.0)),
                Item("Q", editor=scrubber(0.01)),
                Item("gain", editor=scrubber(0.1)),
                show_border=True,
            ),
            resizable=True),
        row_factory=Equalizer)

    view = View(
        Item("eqs", show_label=False, editor=table_editor),
        width=0.25,
        height=0.5,
        resizable=True)

    @on_trait_change("eqs.h")
    def recalculate_h(self):
        """Compute the frequency response of the cascaded equalizers.

        Only equalizers whose ``h`` is set and has the same length as ``W``
        take part.  The update is best-effort: if the product cannot be
        computed or assigned, ``h`` keeps its previous value.
        """
        try:
            # ``is not None`` instead of ``!= None``: on an array the latter
            # is an element-wise comparison, not a test for "unset".
            responses = np.array([
                eq.h for eq in self.eqs
                if eq.h is not None and len(eq.h) == len(W)
            ])
            self.h = np.prod(responses, axis=0)
        except Exception:
            pass

    def export(self, path):
        """Write the equalizer coefficients to ``path`` as a C source file."""
        # The context manager closes the file even if an equalizer fails
        # while writing its parameters.
        with open(path, "w") as f:
            f.write("double EQ_PARS[][5] = {\n")
            f.write("//b0,b1,b2,a1,a2 // frequency, Q, gain\n")
            for eq in self.eqs:
                eq.export_parameters(f)
            f.write("};\n")
class EnumExample(HasTraits):
    """Example of an Enum trait shown through an EnumEditor value mapping."""

    # The first argument ('Medium') is the default; the rest are the choices.
    priority = Enum('Medium', 'Highest', 'High', 'Medium', 'Low', 'Lowest')

    view = View(
        Item(
            name='priority',
            editor=EnumEditor(
                # Each trait value maps to a "<n>:<label>" display string.
                values={
                    'Highest': '1:Highest',
                    'High': '2:High',
                    'Medium': '3:Medium',
                    'Low': '4:Low',
                    'Lowest': '5:Lowest',
                },
            ),
        ),
    )
def do_slice(self):
    """Open the slice-plot configuration dialog for ``self.pf``.

    A ``SlicePlotSpec`` is created for the parameter file and edited through
    an OK/Cancel view; a ``PlotCreationHandler`` receives the result.
    """
    dialog_items = (
        Item('axis'),
        Item('center'),
        Item('field', editor=EnumEditor(name='field_list')),
    )
    dialog_view = View(*dialog_items,
                       buttons=OKCancelButtons,
                       title="Slicer: %s" % self.pf)
    spec = SlicePlotSpec(pf=self.pf)
    # NOTE(review): ``mw`` is not defined in this function; presumably a
    # module-level main-window reference -- confirm.
    creation_handler = PlotCreationHandler(
        main_window=mw,
        pnode=self,
        model=spec,
        plot_type=SlicePlotTab,
        format="Slice: %s",
    )
    spec.edit_traits(dialog_view, handler=creation_handler)
class TableTest(HasStrictTraits):
    """Test harness showing a list of ``Person`` objects in a TableEditor."""

    # -- Trait definitions ----------------------------------------------------

    people = List(Person)

    # -- Traits view definitions ----------------------------------------------

    # Choices offered by the 'state' column editor.
    _valid_states = List(["AL", "AR", "AZ", "AK"])

    _state_editor = EnumEditor(
        name="_valid_states",
        evaluate=evaluate_value,
        object='table_editor_object',
    )

    table_editor = TableEditor(
        columns=[
            ObjectColumn(name='name'),
            ObjectColumn(name='age'),
            ObjectColumn(name='phone'),
            ObjectColumn(name='state', editor=_state_editor),
        ],
        editable=True,
        deletable=True,
        sortable=True,
        sort_model=True,
        show_lines=True,
        orientation='vertical',
        show_column_labels=True,
        edit_view=View(
            ['name', 'age', 'phone', 'state', '|[]'],
            resizable=True,
        ),
        filter=None,
        filters=filters,
        row_factory=Person,
    )

    traits_view = View(
        [Item('people', id='people', editor=table_editor), '|[]<>'],
        title='Table Editor Test',
        id='enthought.traits.ui.tests.table_editor_test',
        dock='horizontal',
        width=.4,
        height=.3,
        resizable=True,
        buttons=NoButtons,
        kind='live',
    )
def default_traits_view(self):
    """Return the viewer window: a control row above the plot component."""
    controls = HGroup(
        Item("current_map", label=u"颜色映射",
             editor=EnumEditor(name="object.color_maps")),
        Item("reverse_map", label=u"反转颜色"),
        Item("position", label=u"位置", style="readonly"),
    )
    canvas = Item("plot", show_label=False, editor=ComponentEditor())
    return View(
        VGroup(controls, canvas),
        resizable = True,
        width = 550,
        height = 300,
        title = u"Mandelbrot观察器"
    )
def createView(self):
    """Set up a view for the traits."""
    # Items for the three index selectors; each is visible once the number
    # of dimensions reaches its position.  They are built but currently not
    # placed in the view -- only 'item' is shown.
    condition_prefix = self.runtimeElementCondition + self.numDimensionsString
    indexItems = []
    for position in (0, 1, 2):
        indexItems.append(
            Item('index%d' % position,
                 editor=EnumEditor(name='index%d_enum' % position),
                 visible_when=condition_prefix + '>=%d' % (position + 1)))

    items = [Item('item')]
    content = Group(*items, orientation='horizontal', show_labels=False)
    self.view = View(content, buttons=NoButtons, handler=SelectorHandler)
def _get_table_editor ( self, names ):
    """Return the table editor used for editing this filter's rules.

    ``names`` are the choices offered by the name editor that is handed to
    every newly created ``GenericTableFilterRule`` row.
    """
    from enthought.traits.ui.api import TableEditor

    row_kw = { 'filter':      self,
               'name_editor': EnumEditor( values = names ) }
    return TableEditor(
        columns        = generic_table_filter_rule_columns,
        orientation    = 'vertical',
        deletable      = True,
        sortable       = False,
        configurable   = False,
        auto_size      = False,
        auto_add       = True,
        row_factory    = GenericTableFilterRule,
        row_factory_kw = row_kw )
class Address(HasTraits):
    """ Demo class for demonstrating dynamic redefinition of valid trait
        values.

        The ``city`` editor takes its choices from ``handler.cities``, i.e.
        from the ``AddressHandler`` attached to the view.
    """

    street_address = Str
    # ``list(cities)`` instead of ``cities.keys()``: indexing the result of
    # ``keys()`` only works where it returns a list, whereas ``list(...)``
    # works on every Python version and yields the same keys in the same
    # order.  The first key is the default state.
    st = Enum(list(cities)[0], list(cities))
    city = Str

    view = View(Item(name='street_address'),
                Item(name='st', label='State'),
                Item(name='city',
                     editor=EnumEditor(name='handler.cities'),
                     id='cityedit'),
                title='Address Information',
                buttons=['OK'],
                resizable=True,
                handler=AddressHandler)
class Constraint(HasTraits):
    """One constraint row: a named parameter bound to a constraint variable.

    The selectable constraint names come from the keys of
    ``runcase.constraint_variables``.
    """

    name = String
    runcase = Any

    # Recomputed when ``runcase.constraint_variables`` changes.
    constraint_variables_available = Property(
        List(String), depends_on='runcase.constraint_variables')

    @cached_property
    def _get_constraint_variables_available(self):
        """Return the keys of the run case's constraint variables."""
        variables = self.runcase.constraint_variables
        return variables.keys()

    constraint_name = String
    value = Float
    pattern = String
    cmd = String

    # Table editor used for lists of Constraint objects.
    editor = TableEditor(
        auto_size=False,
        columns=[
            ObjectColumn(name='name', editable=False, label='Parameter'),
            ObjectColumn(
                name='constraint_name',
                label='Constraint',
                editor=EnumEditor(name='constraint_variables_available')),
            ObjectColumn(
                name='value',
                label='Value',
                editor=TextEditor(evaluate=float, enter_set=True,
                                  auto_set=False)),
        ])
class TrimCaseConfig(RunOptionsConfig): trimcase = Instance(TrimCase, TrimCase()) #runcase = DelegatesTo('trimcase') varying_param = String('velocity') varying_expr = Expression('100') param_name_list = Property(List(String), depends_on='trimcase.parameters') @cached_property def _get_param_name_list(self): return sorted(self.trimcase.parameters.keys(), key=lambda x: x.lower()) @on_trait_change('trimcase.parameter_view.value,trimcase.type') def on_trimcase_changed(self, object, name, old, new): print 'trimcase_changed' if name == 'value': self.trimcase.runcase.update_trim_case(self.trimcase, object) self.trimcase.update_parameters_from_avl() elif name == 'type': self.trimcase.update_parameters_from_avl() self.param_name_list[0] # send changed value of trimcase to avl, parameter is Parameter instance def update_trim_case(self, trimcase, parameter): print 'update_trim_case' self.get_parameters_info_from_avl(self.avl) self.avl.sendline('oper') self.avl.sendline(trimcase.type_) self.avl.expect(AVL.patterns['/oper/m']) #print self.parameters.keys() #for p, v in trimcase.parameters.iteritems(): self.avl.sendline('%s %f' % (parameter.cmd, parameter.value)) AVL.goto_state(self.avl) custom_group = Group( Item('object.trimcase.type'), Group(Item('object.trimcase.parameter_view', editor=Parameter.editor, show_label=False, height=0.4), label='Parameters', scrollable=True), #Item('trimcase', style='custom', show_label=False), Group(Item('varying_param', editor=EnumEditor(name='param_name_list')), Item('varying_expr', label='Expression')))
class AnimatedGIFDemo(HasTraits):
    """Demo showing an animated GIF with a play toggle and a file selector."""

    # The animated GIF file to display (defaults to the first known file):
    gif_file = File(files[0])

    # Is the animation playing or not?
    playing = Bool(True)

    # The traits view: the animation and its 'playing' checkbox on top, a
    # separator, then a drop-down for picking another file.
    view = View(
        VGroup(
            HGroup(
                Item('gif_file',
                     editor=AnimatedGIFEditor(playing='playing'),
                     show_label=False),
                Item('playing'),
            ),
            '_',
            Item('gif_file',
                 label='GIF File',
                 editor=EnumEditor(values=files)),
        ),
        title='Animated GIF Demo',
        buttons=['OK'],
    )
class Sp4ArrayFileSource(ArraySource):
    """Array source that reads its data from an Sp4 file.

    The array obtained from the file through ``GetArray()`` is reduced to
    scalar data according to ``component``: amplitude, phase, real part or
    imaginary part.
    """

    file_name = Str
    component = Enum('Amp', 'Phase', 'Real', 'Imag')
    set_component = Str
    f = Instance(sp4.Sp4File)

    view = View(Group(Item(name='transpose_input_array'),
                      Item(name='file_name',
                           editor=TextEditor(auto_set=False,
                                             enter_set=True)),
                      Item(name='component',
                           editor=EnumEditor(values=component)),
                      Item(name='scalar_name'),
                      Item(name='vector_name'),
                      Item(name='spacing'),
                      Item(name='origin'),
                      Item(name='update_image_data', show_label=False),
                      show_labels=True)
                )

    def __init__(self, **kw_args):
        """Open ``file_name`` and load the amplitude component."""
        super(Sp4ArrayFileSource, self).__init__(**kw_args)
        fn = kw_args.get('file_name', None)
        if fn is not None:
            self.file_name = fn
        self._open(self.file_name)
        self.component = "Amp"
        # Called explicitly because assigning the default value 'Amp' does
        # not trigger a change notification.
        self._component_changed('Amp')

    def _open(self, fn):
        """Open the Sp4 file ``fn`` and keep it in ``self.f``."""
        # The original ignored ``fn`` and always used ``self.file_name``.
        self.f = sp4.Sp4File(fn)

    def _component_changed(self, info):
        """Recompute ``scalar_data`` for the selected component."""
        # 'Real' and 'Imag' are offered by the ``component`` trait but were
        # previously ignored here, leaving stale data on screen.
        if info == "Amp":
            self.scalar_data = numpy.abs(self.f.GetArray())
        elif info == "Phase":
            self.scalar_data = numpy.angle(self.f.GetArray())
        elif info == "Real":
            self.scalar_data = numpy.real(self.f.GetArray())
        elif info == "Imag":
            self.scalar_data = numpy.imag(self.f.GetArray())
        self.update()

    def _file_name_changed(self, info):
        """Echo the new file name to stdout."""
        print(self.file_name)
class ConstraintConfig(HasTraits):
    """Pairs a ``Constraint`` with the expression entered for its value."""

    constraint = Instance(Constraint)
    expr = Expression()

    def __init__(self, *args, **kwargs):
        """Initialise ``expr`` with the string form of the constraint value."""
        super(ConstraintConfig, self).__init__(*args, **kwargs)
        initial_value = self.constraint.value
        self.expr = str(initial_value)

    # Forwarded from the wrapped constraint so the column editor below can
    # find it on this object.
    constraint_variables_available = DelegatesTo('constraint')

    # Table editor used for lists of ConstraintConfig objects.
    editor = TableEditor(
        auto_size=False,
        columns=[
            ObjectColumn(name='constraint.name',
                         editable=False,
                         label='Parameter'),
            ObjectColumn(
                name='constraint.constraint_name',
                label='Constraint',
                editor=EnumEditor(name='constraint_variables_available')),
            ObjectColumn(name='expr', label='Expression'),
        ])
class HierarchyImporter(HasTraits):
    """Options dialog for importing a grid hierarchy from ``pf``."""

    pf = Any
    min_grid_level = Int(0)
    max_level = Int(1)
    number_of_levels = Range(0, 13)
    max_import_levels = Property(depends_on='min_grid_level')
    field = Str("Density")
    field_list = List
    center_on_max = Bool(True)
    center = CArray(shape=(3, ), dtype='float64')
    cache = Bool(True)
    smoothed = Bool(True)
    show_grids = Bool(True)

    def _field_list_default(self):
        """Return the native fields followed by the derived fields.

        Both lists on ``pf.h`` are sorted in place before being joined.
        """
        hierarchy = self.pf.h
        native_fields = hierarchy.field_list
        derived_fields = hierarchy.derived_field_list
        native_fields.sort()
        derived_fields.sort()
        return native_fields + derived_fields

    default_view = View(
        Item('min_grid_level',
             editor=RangeEditor(low=0, high_name='max_level')),
        Item('number_of_levels',
             editor=RangeEditor(low=1, high_name='max_import_levels')),
        Item('field', editor=EnumEditor(name='field_list')),
        Item('center_on_max'),
        Item('center', enabled_when='not object.center_on_max'),
        Item('smoothed'),
        Item('cache', label='Pre-load data'),
        Item('show_grids'),
        buttons=OKCancelButtons,
    )

    def _center_default(self):
        """Default centre: the middle of the unit cube."""
        return [0.5, 0.5, 0.5]

    @cached_property
    def _get_max_import_levels(self):
        """Levels available above ``min_grid_level``, capped at 13."""
        levels_above_min = self.pf.h.max_level - self.min_grid_level + 1
        return min(13, levels_above_min)
class Graph(HasTraits):
    """
    Plot component: data selection controls on the left and the plot
    widget on the right.
    """

    name = Str                        # plot name, used as the axes title
    data_source = Instance(DataSource)  # data source holding the data
    figure = Instance(Figure)         # Figure object driving the plot widget
    selected_xaxis = Str              # name of the data used for the X axis
    selected_items = List             # names of the data plotted on the Y axis
    clear_button = Button(u"清除")     # clears every Y axis selection

    view = View(
        # HSplit gives a left and a right area with a draggable divider.
        HSplit(
            # Left area: one group with the selection controls.
            VGroup(
                Item("name"),             # plot name edit box
                Item("clear_button"),     # clear button
                Heading(u"X轴数据"),       # static text
                # X axis selector: an EnumEditor whose choices come from
                # the data source's ``names``.
                Item("selected_xaxis",
                     editor=EnumEditor(name="object.data_source.names",
                                       format_str=u"%s")),
                Heading(u"Y轴数据"),       # static text
                # Y axis selector: several entries may be picked, so a
                # two-column check list is used.
                Item("selected_items", style="custom",
                     editor=CheckListEditor(name="object.data_source.names",
                                            cols=2, format_str=u"%s")),
                show_border = True,    # draw the group border
                scrollable = True,     # scroll when there are many controls
                show_labels = False    # no labels for any control in the group
            ),
            # Right area: the plot widget.
            Item("figure", editor=MPLFigureEditor(), show_label=False,
                 width=600)
        )
    )

    def _name_changed(self):
        """Update the plot title when the plot name changes."""
        axes = self.figure.axes[0]
        axes.set_title(self.name)
        self.figure.canvas.draw()

    def _clear_button_fired(self):
        """Handle the clear button: drop all Y axis selections."""
        self.selected_items = []
        self.update()

    def _figure_default(self):
        """Default for ``figure``: a new Figure with a single axes."""
        fig = Figure()
        # Add the plotting area, leaving a margin on every side.
        fig.add_axes([0.1, 0.1, 0.85, 0.80])
        return fig

    def _selected_items_changed(self):
        """Redraw when the Y axis selection changes."""
        self.update()

    def _selected_xaxis_changed(self):
        """Redraw when the X axis selection changes."""
        self.update()

    def update(self):
        """Redraw all curves.

        The axes are cleared first; if the X axis data cannot be looked up,
        nothing further is drawn.
        """
        axes = self.figure.axes[0]
        axes.clear()
        try:
            x_values = self.data_source.data[self.selected_xaxis]
        except:
            return
        for series in self.selected_items:
            axes.plot(x_values, self.data_source.data[series], label=series)
        axes.set_xlabel(self.selected_xaxis)
        axes.set_title(self.name)
        axes.legend()
        self.figure.canvas.draw()
# -*- coding: utf-8 -*-
class TVTKClassChooser(HasTraits):
    """Dialog model for picking a TVTK class by name or by documentation search.

    Typing a class name narrows ``completions``; searching the class
    documentation narrows ``available``.  ``doc`` shows the documentation of
    the currently selected class, or the search help text.
    """

    # The selected object, is None if no valid class_name was made.
    object = Property

    # The TVTK class name to choose.
    class_name = Str('', desc='class name of TVTK class (case sensitive)')

    # The string to search for in the class docs -- the search supports
    # 'and' and 'or' keywords.
    search = Str('', desc=('string to search in TVTK class documentation '
                           'supports the "and" and "or" keywords. '
                           'press <Enter> to start search. '
                           'This is case insensitive.'))

    clear_search = Button

    # The class documentation.
    doc = Str(_search_help_doc)

    # Completions for the choice of class.
    completions = List(Str)

    # List of available class names as strings.
    available = List(TVTK_CLASSES)

    ########################################
    # Private traits.

    finder = Instance(DocSearch)

    n_completion = Int(25)

    ########################################
    # View related traits.

    view = View(
        Group(
            Item(name='class_name', editor=EnumEditor(name='available')),
            Item(name='class_name', has_focus=True),
            Item(name='search',
                 editor=TextEditor(enter_set=True, auto_set=False)),
            Item(name='clear_search', show_label=False),
            Item('_'),
            Item(name='completions',
                 editor=ListEditor(columns=3),
                 style='readonly'),
            Item(name='doc',
                 resizable=True,
                 label='Documentation',
                 style='custom'),
        ),
        id='tvtk_doc',
        resizable=True,
        width=800,
        height=600,
        title='TVTK class chooser',
        buttons = ["OK", "Cancel"]
    )

    ######################################################################
    # `object` interface.
    ######################################################################
    def __init__(self, **traits):
        """Remember the initial ``available`` list so searches can be undone."""
        super(TVTKClassChooser, self).__init__(**traits)
        self._orig_available = list(self.available)

    ######################################################################
    # Non-public interface.
    ######################################################################
    def _get_object(self):
        """Instantiate the class named by ``class_name``.

        Returns None for an empty name, or when the lookup or the
        construction raises AttributeError or TypeError.
        """
        if len(self.class_name) == 0:
            return None
        try:
            return getattr(tvtk, self.class_name)()
        except (AttributeError, TypeError):
            return None

    def _class_name_changed(self, value):
        """Refresh completions and documentation for the typed name."""
        matches = [candidate for candidate in self.available
                   if candidate.startswith(value)]
        self.completions = matches[:self.n_completion]
        # A unique match is completed automatically.
        if len(matches) == 1 and value != matches[0]:
            self.class_name = matches[0]

        instance = self.object
        if instance is None:
            self.doc = _search_help_doc
        else:
            self.doc = get_tvtk_class_doc(instance)

    def _finder_default(self):
        return DocSearch()

    def _clear_search_fired(self):
        self.search = ''

    def _search_changed(self, value):
        """Restrict ``available`` to the classes matching the search text.

        Searches shorter than three characters, or without any hit, restore
        the full list.  A single hit selects that class directly.
        """
        if len(value) < 3:
            self.available = self._orig_available
            return

        hits = self.finder.search(value)
        hit_count = len(hits)
        if hit_count == 0:
            self.available = self._orig_available
        elif hit_count == 1:
            self.class_name = hits[0]
        else:
            self.available = hits
            self.completions = hits[:self.n_completion]
class IFSDesigner(HasTraits):
    """Interactive designer for iterated function systems (IFS).

    Triangles are edited with a ``TrianglesTool`` overlay on the plot; a
    timer repeatedly calls ``ifs_calculate`` to add points to the scatter
    plot.  Named point sets can be saved to and removed from a pickle file
    in the working directory.
    """

    plot = Instance(Plot)
    clear = Bool(False)
    draw = Bool(False)
    timer = Instance(Timer)
    ifs_names = List()
    ifs_points = List()
    current_name = Str()
    save_button = Button(u"保存当前IFS")
    unsave_button = Button(u"删除当前IFS")

    view = View(
        HGroup(
            Item("current_name", editor = EnumEditor(name="object.ifs_names")),
            Item("save_button"),
            Item("unsave_button"),
            show_labels=False
        ),
        Item("plot", editor=ComponentEditor(),show_label=False),
        resizable=True,
        width = 500,
        height = 500,
        title = u"IFS图形设计器"
    )

    def __init__(self):
        """Build the plot, load previously saved IFS data and start the timer."""
        super(IFSDesigner, self).__init__()
        self.data = ArrayPlotData()
        self.set_empty_data()
        self.plot = Plot(self.data, padding=10)
        self.plot.plot(("x","y", "c"),
                       type="cmap_scatter",
                       marker_size=1,
                       color_mapper=make_color_map(),
                       line_width=0)
        self.plot.x_grid.visible = False
        self.plot.y_grid.visible = False
        self.plot.x_axis.visible = False
        self.plot.y_axis.visible = False
        self.tool = TrianglesTool(self.plot)
        self.plot.overlays.append(self.tool)

        # Loading is best-effort: a missing or unreadable data file simply
        # leaves the designer empty.
        # NOTE(review): pickle.load executes code embedded in the file; only
        # use data files from a trusted source.
        try:
            with open("ifs_chaco.data", "rb") as f:
                saved = pickle.load(f)
            self.ifs_names = [entry[0] for entry in saved]
            self.ifs_points = [np.array(entry[1]) for entry in saved]
            if self.ifs_names:
                self.current_name = self.ifs_names[-1]
        except Exception:
            pass

        self.tool.on_trait_change(self.triangle_changed, 'changed')
        self.timer = Timer(10, self.ifs_calculate)

    def set_empty_data(self):
        """Remove all points from the plot data."""
        self.data["x"] = np.array([])
        self.data["y"] = np.array([])
        self.data["c"] = np.array([])

    def triangle_changed(self):
        """React to edits of the triangle tool's point list.

        Each complete triangle (point count divisible by 3) empties the plot
        data; fewer than 9 points stop drawing; 9 or more points forming
        complete triangles request a restart of the iteration.
        """
        count = len(self.tool.points)
        complete = count % 3 == 0
        if complete:
            self.set_empty_data()
        if count < 9:
            self.draw = False
        if count >= 9 and complete:
            self.clear = True

    def ifs_calculate(self):
        """Timer callback: extend the plotted point set by one batch."""
        if self.clear:
            self.clear = False
            self.initpos = [0, 0]
            # The first 100 points of the iteration are not plotted.
            x, y, c = ifs(self.tool.get_areas(), self.tool.get_eqs(),
                          self.initpos, 100)
            self.initpos = [x[-1], y[-1]]
            self.draw = True

        if self.draw and len(self.data["x"]) < ITER_COUNT * ITER_TIMES:
            x, y, c = ifs(self.tool.get_areas(), self.tool.get_eqs(),
                          self.initpos, ITER_COUNT)
            ox, oy, oc = self.data["x"], self.data["y"], self.data["c"]
            # Batches containing very large coordinates are discarded.
            if np.max(np.abs(x)) < 1000000 and np.max(np.abs(y)) < 1000000:
                self.initpos = [x[-1], y[-1]]
                x, y, z = np.hstack((ox, x)), np.hstack((oy, y)), np.hstack((oc, c))
                self.data["x"], self.data["y"], self.data["c"] = x, y, z

                # Adjust the plot ranges, keeping the X:Y aspect ratio at 1:1.
                xmin, xmax = np.min(x), np.max(x)
                ymin, ymax = np.min(y), np.max(y)
                xptp, yptp = xmax - xmin, ymax - ymin
                xcenter, ycenter = (xmax + xmin) / 2.0, (ymax + ymin) / 2.0
                w, h = float(self.plot.width), float(self.plot.height)
                # A plot without a size yet would give a division by zero.
                if w > 0 and h > 0:
                    scale = max(xptp / w, yptp / h)
                    self.plot.index_range.low = xcenter - 0.5*scale*w
                    self.plot.index_range.high = xcenter + 0.5*scale*w
                    self.plot.value_range.low = ycenter - 0.5*scale*h
                    # Fixed: the "* h" factor was missing, which made the
                    # value range asymmetric around ycenter.
                    self.plot.value_range.high = ycenter + 0.5*scale*h

    def _current_name_changed(self):
        """Load the points stored under the newly selected name."""
        # The name is empty (or unknown) when nothing is saved.
        if self.current_name not in self.ifs_names:
            return
        index = self.ifs_names.index(self.current_name)
        self.tool.points = list(self.ifs_points[index])
        self.tool.changed = True
        self.clear = True

    def _save_button_fired(self):
        """Handle the save button: store the current points under a name."""
        ask = AskName(name = self.current_name)
        if ask.configure_traits():
            if ask.name not in self.ifs_names:
                self.ifs_names.append( ask.name )
                self.ifs_points.append( self.tool.points[:] )
            else:
                index = self.ifs_names.index(ask.name)
                self.ifs_names[index] = ask.name
                self.ifs_points[index] = self.tool.points[:]
            self.save_data()
            self.current_name = ask.name

    def _unsave_button_fired(self):
        """Handle the delete button: remove the currently selected entry."""
        if self.current_name not in self.ifs_names:
            return
        index = self.ifs_names.index(self.current_name)
        del self.ifs_names[index]
        del self.ifs_points[index]
        if self.ifs_names:
            # Fixed: the index was compared with a list element (a string)
            # instead of the list length, and the element lookup itself went
            # out of range after deleting the last entry.
            if index >= len(self.ifs_names):
                index -= 1
            self.current_name = self.ifs_names[index]
        else:
            self.current_name = ""
        self.save_data()

    def save_data(self):
        """Write all named point sets to the data file."""
        # Same file name as the one read in __init__; the original wrote to
        # "IFS_chaco.data", which is a different file on case-sensitive
        # file systems.
        with open("ifs_chaco.data", "wb") as f:
            # list(): on Python 3 zip() returns an iterator.
            pickle.dump(list(zip(self.ifs_names, self.ifs_points)), f)
def make_editor(trait=None):
    """Return an EnumEditor bound to the name held in ``trait.values_name``.

    ``trait`` must provide a ``values_name`` attribute; with the default
    ``None`` the attribute lookup raises AttributeError.
    """
    values_name = trait.values_name
    return EnumEditor(name=values_name)
class ParticleArrayHelper(HasTraits):
    """Manage one particle array together with its mlab plot.

    Assigning ``particle_array`` creates the plot on first use and refreshes
    it afterwards; changing ``scalar`` switches the plotted scalar.
    """

    # The particle array we manage.
    particle_array = Instance(ParticleArray)

    # The name of the particle array.
    name = Str

    # The active scalar to view.
    scalar = Str('rho', desc='name of the active scalar to view')

    # The mlab plot for this particle array.
    plot = Instance(PipelineBase)

    # List of available scalars in the particle array.
    scalar_list = List(Str)

    scene = Instance(MlabSceneModel)

    # Sync'd trait with the scalar lut manager.
    show_legend = Bool(False, desc='if the scalar legend is to be displayed')

    # Sync'd trait with the dataset to turn on/off visibility.
    visible = Bool(True, desc='if the particle array is to be displayed')

    ########################################
    # View related code.
    view = View(
        Item(name='name', show_label=False, editor=TitleEditor()),
        Group(
            Item(name='visible'),
            Item(name='scalar', editor=EnumEditor(name='scalar_list')),
            Item(name='show_legend'),
        ),
    )

    ######################################################################
    # Private interface.
    ######################################################################
    def _particle_array_changed(self, pa):
        """Create the plot for ``pa`` or refresh the existing one."""
        self.name = pa.name
        # Setup the scalars.
        self.scalar_list = sorted(pa.properties.keys())

        # Gather the data to show.
        x, y, z, u, v, w = pa.x, pa.y, pa.z, pa.u, pa.v, pa.w
        s = getattr(pa, self.scalar)
        existing = self.plot
        mlab = self.scene.mlab

        if existing is not None:
            # set() when the number of points is unchanged, reset() otherwise.
            source = existing.mlab_source
            if len(x) == len(source.x):
                source.set(x=x, y=y, z=z, scalars=s, u=u, v=v, w=w)
            else:
                source.reset(x=x, y=y, z=z, scalars=s, u=u, v=v, w=w)
            return

        # First array: build the pipeline and link the UI traits to it.
        src = mlab.pipeline.vector_scatter(x, y, z, u, v, w, scalars=s)
        glyph = mlab.pipeline.glyph(src, mode='point', scale_mode='none')
        glyph.actor.property.point_size = 3
        glyph.mlab_source.dataset.point_data.scalars.name = self.scalar
        lut_manager = glyph.module_manager.scalar_lut_manager
        lut_manager.set(show_legend=self.show_legend,
                        use_default_name=False,
                        data_name=self.scalar)
        self.sync_trait('visible', glyph.mlab_source.m_data, mutual=True)
        self.sync_trait('show_legend', lut_manager, mutual=True)
        self.plot = glyph

    def _scalar_changed(self, value):
        """Show the scalar called ``value`` on the existing plot, if any."""
        current = self.plot
        if current is None:
            return
        current.mlab_source.scalars = getattr(self.particle_array, value)
        current.module_manager.scalar_lut_manager.data_name = value
class DataSourceWizardView(DataSourceWizard): #---------------------------------------------------------------------- # Private traits #---------------------------------------------------------------------- _top_label = Str('Describe your data') _info_text = Str('Array size do not match') _array_label = Str('Available arrays') _data_type_text = Str("What does your data represents?" ) _lines_text = Str("Connect the points with lines" ) _scalar_data_text = Str("Array giving the value of the scalars") _optional_scalar_data_text = Str("Associate scalars with the data points") _connectivity_text = Str("Array giving the triangles") _vector_data_text = Str("Associate vector components") _position_text = Property(depends_on="position_type_") _position_text_dict = {'explicit': 'Coordinnates of the data points:', 'orthogonal grid': 'Position of the layers along each axis:', } def _get__position_text(self): return self._position_text_dict.get(self.position_type_, "") _shown_help_text = Str _data_sources_wrappers = Property(depends_on='data_sources') def _get__data_sources_wrappers(self): return [ ArrayColumnWrapper(name=name, shape=repr(self.data_sources[name].shape)) for name in self._data_sources_names ] # A traits pointing to the object, to play well with traitsUI _self = Instance(DataSourceWizard) _suitable_traits_view = Property(depends_on="data_type_") def _get__suitable_traits_view(self): return "_%s_data_view" % self.data_type_ ui = Any(False) _preview_button = Button(label='Preview structure') def __preview_button_fired(self): if self.ui: self.build_data_source() self.preview() _ok_button = Button(label='OK') def __ok_button_fired(self): if self.ui: self.ui.dispose() self.build_data_source() _cancel_button = Button(label='Cancel') def __cancel_button_fired(self): if self.ui: self.ui.dispose() _is_ok = Bool _is_not_ok = Bool def _anytrait_changed(self): """ Validates if the OK button is enabled. 
""" if self.ui: self._is_ok = self.check_arrays() self._is_not_ok = not self._is_ok _preview_window = Instance(PreviewWindow, ()) _info_image = Instance(ImageResource, ImageLibrary.image_resource('@std:alert16',)) #---------------------------------------------------------------------- # TraitsUI views #---------------------------------------------------------------------- _coordinates_group = \ HGroup( Item('position_x', label='x', editor=EnumEditor(name='_data_sources_names', invalid='_is_not_ok')), Item('position_y', label='y', editor=EnumEditor(name='_data_sources_names', invalid='_is_not_ok')), Item('position_z', label='z', editor=EnumEditor(name='_data_sources_names', invalid='_is_not_ok')), ) _position_group = \ Group( Item('position_type'), Group( Item('_position_text', style='readonly', resizable=False, show_label=False), _coordinates_group, visible_when='not position_type_=="image data"', ), Group( Item('grid_shape_source_', label='Grid shape', editor=EnumEditor( name='_grid_shape_source_labels', invalid='_is_not_ok')), HGroup( spring, Item('grid_shape', style='custom', editor=ArrayEditor(width=-60), show_label=False), enabled_when='grid_shape_source==""', ), visible_when='position_type_=="image data"', ), label='Position of the data points', show_border=True, show_labels=False, ), _connectivity_group = \ Group( HGroup( Item('_connectivity_text', style='readonly', resizable=False), spring, Item('connectivity_triangles', editor=EnumEditor(name='_data_sources_names'), show_label=False, ), show_labels=False, ), label='Connectivity information', show_border=True, show_labels=False, enabled_when='position_type_=="explicit"', ), _scalar_data_group = \ Group( Item('_scalar_data_text', style='readonly', resizable=False, show_label=False), HGroup( spring, Item('scalar_data', editor=EnumEditor(name='_data_sources_names', invalid='_is_not_ok')), show_labels=False, ), label='Scalar value', show_border=True, show_labels=False, ) _optional_scalar_data_group = \ Group( 
HGroup( 'has_scalar_data', Item('_optional_scalar_data_text', resizable=False, style='readonly'), show_labels=False, ), Item('_scalar_data_text', style='readonly', resizable=False, enabled_when='has_scalar_data', show_label=False), HGroup( spring, Item('scalar_data', editor=EnumEditor(name='_data_sources_names', invalid='_is_not_ok'), enabled_when='has_scalar_data'), show_labels=False, ), label='Scalar data', show_border=True, show_labels=False, ), _vector_data_group = \ VGroup( HGroup( Item('vector_u', label='u', editor=EnumEditor(name='_data_sources_names', invalid='_is_not_ok')), Item('vector_v', label='v', editor=EnumEditor(name='_data_sources_names', invalid='_is_not_ok')), Item('vector_w', label='w', editor=EnumEditor(name='_data_sources_names', invalid='_is_not_ok')), ), label='Vector data', show_border=True, ), _optional_vector_data_group = \ VGroup( HGroup( Item('has_vector_data', show_label=False), Item('_vector_data_text', style='readonly', resizable=False, show_label=False), ), HGroup( Item('vector_u', label='u', editor=EnumEditor(name='_data_sources_names', invalid='_is_not_ok')), Item('vector_v', label='v', editor=EnumEditor(name='_data_sources_names', invalid='_is_not_ok')), Item('vector_w', label='w', editor=EnumEditor(name='_data_sources_names', invalid='_is_not_ok')), enabled_when='has_vector_data', ), label='Vector data', show_border=True, ), _array_view = \ View( Item('_array_label', editor=TitleEditor(), show_label=False), Group( Item('_data_sources_wrappers', editor=TabularEditor( adapter = ArrayColumnAdapter(), ), ), show_border=True, show_labels=False )) _questions_view = View( Item('_top_label', editor=TitleEditor(), show_label=False), HGroup( Item('_data_type_text', style='readonly', resizable=False), spring, 'data_type', spring, show_border=True, show_labels=False, ), HGroup( Item('_self', style='custom', editor=InstanceEditor( view_name='_suitable_traits_view'), ), Group( # FIXME: Giving up on context sensitive help # because of lack of 
time. #Group( # Item('_shown_help_text', editor=HTMLEditor(), # width=300, # label='Help', # ), # show_labels=False, # label='Help', #), #Group( Item('_preview_button', enabled_when='_is_ok'), Item('_preview_window', style='custom', label='Preview structure'), show_labels=False, #label='Preview structure', #), #layout='tabbed', #dock='tab', ), show_labels=False, show_border=True, ), ) _point_data_view = \ View(Group( Group(_coordinates_group, label='Position of the data points', show_border=True, ), HGroup( 'lines', Item('_lines_text', style='readonly', resizable=False), label='Lines', show_labels=False, show_border=True, ), _optional_scalar_data_group, _optional_vector_data_group, # XXX: hack to have more vertical space Label('\n'), Label('\n'), Label('\n'), )) _surface_data_view = \ View(Group( _position_group, _connectivity_group, _optional_scalar_data_group, _optional_vector_data_group, )) _vector_data_view = \ View(Group( _vector_data_group, _position_group, _optional_scalar_data_group, )) _volumetric_data_view = \ View(Group( _scalar_data_group, _position_group, _optional_vector_data_group, )) _wizard_view = View( Group( HGroup( Item('_self', style='custom', show_label=False, editor=InstanceEditor(view='_array_view'), width=0.17, ), '_', Item('_self', style='custom', show_label=False, editor=InstanceEditor(view='_questions_view'), ), ), HGroup( Item('_info_image', editor=ImageEditor(), visible_when="_is_not_ok"), Item('_info_text', style='readonly', resizable=False, visible_when="_is_not_ok"), spring, '_cancel_button', Item('_ok_button', enabled_when='_is_ok'), show_labels=False, ), ), title='Import arrays', resizable=True, ) #---------------------------------------------------------------------- # Public interface #---------------------------------------------------------------------- def __init__(self, **traits): DataSourceFactory.__init__(self, **traits) self._self = self def view_wizard(self): """ Pops up the view of the wizard, and keeps the reference it 
to be able to close it. """ # FIXME: Workaround for traits bug in enabled_when self.position_type_ self.data_type_ self._suitable_traits_view self.grid_shape_source self._is_ok self.ui = self.edit_traits(view='_wizard_view') def preview(self): """ Display a preview of the data structure in the preview window. """ self._preview_window.clear() self._preview_window.add_source(self.data_source) data = lambda name: self.data_sources[name] g = Glyph() g.glyph.glyph_source.glyph_source = \ g.glyph.glyph_source.glyph_list[0] g.glyph.scale_mode = 'data_scaling_off' if not (self.has_vector_data or self.data_type_ == 'vector'): g.glyph.glyph_source.glyph_source.glyph_type = 'cross' g.actor.property.representation = 'points' g.actor.property.point_size = 3. self._preview_window.add_module(g) if not self.data_type_ in ('point', 'vector') or self.lines: s = Surface() s.actor.property.opacity = 0.3 self._preview_window.add_module(s) if not self.data_type_ == 'point': self._preview_window.add_filter(ExtractEdges()) s = Surface() s.actor.property.opacity = 0.2 self._preview_window.add_module(s)
class FitGui(HasTraits):
    """
    This class represents the fitgui application state.

    It holds the data, the weights, the current (traited) model, the plot
    objects used to display them, and the selection tools, along with the
    TraitsUI views used to edit all of the above.
    """

    plot = Instance(Plot)
    colorbar = Instance(ColorBar)
    plotcontainer = Instance(HPlotContainer)
    tmodel = Instance(TraitedModel, allow_none=False)
    nomodel = Property
    newmodel = Button('New Model...')
    fitmodel = Button('Fit Model')
    showerror = Button('Fit Error')
    updatemodelplot = Button('Update Model Plot')
    autoupdate = Bool(True)
    data = Array(dtype=float, shape=(2, None))
    weights = Array
    weighttype = Enum(('custom', 'equal', 'lin bins', 'log bins'))
    weightsvary = Property(Bool)
    weights0rem = Bool(True)
    modelselector = NewModelSelector
    ytype = Enum(('data and model', 'residuals'))

    zoomtool = Instance(ZoomTool)
    pantool = Instance(PanTool)

    scattertool = Enum(None, 'clicktoggle', 'clicksingle', 'clickimmediate',
                       'lassoadd', 'lassoremove', 'lassoinvert')
    selectedi = Property  # indices of the selected objects
    weightchangesel = Button('Set Selection To')
    weightchangeto = Float(1.0)
    delsel = Button('Delete Selected')
    unselectonaction = Bool(True)
    clearsel = Button('Clear Selections')
    # name of the last selection action button that fired, or 'None'
    lastselaction = Str('None')

    datasymb = Button('Data Symbol...')
    modline = Button('Model Line...')

    savews = Button('Save Weights')
    loadws = Button('Load Weights')
    _savedws = Array

    plotname = Property

    # firing `updatestats` invalidates the cached chi2 properties
    updatestats = Event
    chi2 = Property(Float, depends_on='updatestats')
    chi2r = Property(Float, depends_on='updatestats')

    nmod = Int(1024)
    # replaced by an actual View instance in __init__
    modelpanel = View

    panel_view = View(
        VGroup(
            Item('plot', editor=ComponentEditor(), show_label=False),
            HGroup(
                Item('tmodel.modelname', show_label=False, style='readonly'),
                Item('nmod', label='Number of model points'),
                Item('updatemodelplot', show_label=False,
                     enabled_when='not autoupdate'),
                Item('autoupdate', label='Auto?'),
            ),
        ),
        title='Model Data Fitter'
    )

    selection_view = View(
        Group(
            Item('scattertool', label='Selection Mode',
                 editor=EnumEditor(values={None: '1:No Selection',
                                           'clicktoggle': '3:Toggle Select',
                                           'clicksingle': '2:Single Select',
                                           'clickimmediate': '7:Immediate',
                                           'lassoadd': '4:Add with Lasso',
                                           'lassoremove': '5:Remove with Lasso',
                                           'lassoinvert': '6:Invert with Lasso'})),
            Item('unselectonaction', label='Clear Selection on Action?'),
            Item('clearsel', show_label=False),
            Item('weightchangesel', show_label=False),
            Item('weightchangeto', label='Weight'),
            Item('delsel', show_label=False)
        ),
        title='Selection Options'
    )

    traits_view = View(
        VGroup(
            HGroup(
                Item('object.plot.index_scale', label='x-scaling',
                     enabled_when='object.plot.index_mapper.range.low>0 or object.plot.index_scale=="log"'),
                spring,
                Item('ytype', label='y-data'),
                Item('object.plot.value_scale', label='y-scaling',
                     enabled_when='object.plot.value_mapper.range.low>0 or object.plot.value_scale=="log"')
            ),
            Item('plotcontainer', editor=ComponentEditor(), show_label=False),
            HGroup(
                VGroup(
                    HGroup(
                        Item('weighttype', label='Weights:'),
                        Item('savews', show_label=False),
                        Item('loadws', enabled_when='_savedws',
                             show_label=False)
                    ),
                    Item('weights0rem',
                         label='Remove 0-weight points for fit?'),
                    HGroup(
                        Item('newmodel', show_label=False),
                        Item('fitmodel', show_label=False),
                        Item('showerror', show_label=False,
                             enabled_when='tmodel.lastfitfailure'),
                        VGroup(
                            Item('chi2', label='Chi2:', style='readonly',
                                 format_str='%6.6g',
                                 visible_when='tmodel.model is not None'),
                            Item('chi2r', label='reduced:', style='readonly',
                                 format_str='%6.6g',
                                 visible_when='tmodel.model is not None')
                        )
                    ),
                    springy=False
                ),
                spring,
                VGroup(
                    HGroup(
                        Item('autoupdate', label='Auto?'),
                        Item('updatemodelplot', show_label=False,
                             enabled_when='not autoupdate')
                    ),
                    Item('nmod', label='Nmodel'),
                    HGroup(
                        Item('datasymb', show_label=False),
                        Item('modline', show_label=False)
                    ),
                    springy=False
                ),
                springy=True
            ),
            '_',
            HGroup(
                Item('scattertool', label='Selection Mode',
                     editor=EnumEditor(values={None: '1:No Selection',
                                               'clicktoggle': '3:Toggle Select',
                                               'clicksingle': '2:Single Select',
                                               'clickimmediate': '7:Immediate',
                                               'lassoadd': '4:Add with Lasso',
                                               'lassoremove': '5:Remove with Lasso',
                                               'lassoinvert': '6:Invert with Lasso'})),
                Item('unselectonaction', label='Clear Selection on Action?'),
                Item('clearsel', show_label=False),
                Item('weightchangesel', show_label=False),
                Item('weightchangeto', label='Weight'),
                Item('delsel', show_label=False),
            ),
            Item('tmodel', show_label=False, style='custom',
                 editor=InstanceEditor(kind='subpanel'))
        ),
        handler=FGHandler(),
        resizable=True,
        title='Data Fitting',
        buttons=['OK', 'Cancel'],
        width=700,
        height=900
    )

    def __init__(self, xdata=None, ydata=None, weights=None, model=None,
                 include_models=None, exclude_models=None, fittype=None,
                 **traits):
        """
        :param xdata: the first dimension of the data to be fit
        :type xdata: array-like
        :param ydata: the second dimension of the data to be fit
        :type ydata: array-like
        :param weights:
            The weights to apply to the data. Statistically interpreted as
            inverse errors (*not* inverse variance). May be any of the
            following forms:

            * None for equal weights
            * an array of points that must match `ydata`
            * a 2-sequence of arrays (xierr,yierr) such that xierr matches
              the `xdata` and yierr matches `ydata`
            * a function called as f(params) that returns an array of
              weights that match one of the above two conditions

        :param model: the initial model to use to fit this data
        :type model:
            None, string, or :class:`pymodelfit.core.FunctionModel1D`
            instance.
        :param include_models:
            With `exclude_models`, specifies which models should be
            available in the "new model" dialog (see `models.list_models`
            for syntax).
        :param exclude_models:
            With `include_models`, specifies which models should be
            available in the "new model" dialog (see `models.list_models`
            for syntax).
        :param fittype:
            The fitting technique for the initial fit (see
            :class:`pymodelfit.core.FunctionModel`).
        :type fittype: string

        :raises ValueError:
            If `xdata` or `ydata` is None and the model carries no data.

        kwargs are passed in as any additional traits to apply to the
        application.
        """
        self.modelpanel = View(Label('empty'), kind='subpanel',
                               title='model editor')

        self.tmodel = TraitedModel(model)

        if model is not None and fittype is not None:
            self.tmodel.model.fittype = fittype

        # fall back on data stored in the model for anything not supplied
        if xdata is None or ydata is None:
            if not hasattr(self.tmodel.model, 'data') or self.tmodel.model.data is None:
                raise ValueError('data not provided and no data in model')
            if xdata is None:
                xdata = self.tmodel.model.data[0]
            if ydata is None:
                ydata = self.tmodel.model.data[1]
            if weights is None:
                weights = self.tmodel.model.data[2]

        self.on_trait_change(self._paramsChanged, 'tmodel.paramchange')

        self.modelselector = NewModelSelector(include_models, exclude_models)

        self.data = [xdata, ydata]

        if weights is None:
            self.weights = np.ones_like(xdata)
            self.weighttype = 'equal'
        else:
            self.weights = np.array(weights, copy=True)
            # fire the "Save Weights" button so the supplied weights can be
            # restored later
            self.savews = True

        # collapse multi-dimensional weights into a single value per point
        # for use as the scatter plot color
        weights1d = self.weights
        while len(weights1d.shape) > 1:
            weights1d = np.sum(weights1d**2, axis=0)

        pd = ArrayPlotData(xdata=self.data[0], ydata=self.data[1],
                           weights=weights1d)
        self.plot = plot = Plot(pd, resizable='hv')

        self.scatter = plot.plot(('xdata', 'ydata', 'weights'), name='data',
                                 color_mapper=_cmapblack if self.weights0rem else _cmap,
                                 type='cmap_scatter', marker='circle')[0]

        self.errorplots = None

        if not isinstance(model, FunctionModel1D):
            self.fitmodel = True

        self.updatemodelplot = False  # force plot update - generates xmod and ymod
        plot.plot(('xmod', 'ymod'), name='model', type='line',
                  line_style='dash', color='black', line_width=2)
        # remove the line plot from the x_mapper source so only the data is
        # tied to the scaling
        del plot.x_mapper.range.sources[-1]

        self.on_trait_change(self._rangeChanged,
                             'plot.index_mapper.range.updated')

        self.pantool = PanTool(plot, drag_button='left')
        plot.tools.append(self.pantool)
        self.zoomtool = ZoomTool(plot)
        self.zoomtool.prev_state_key = KeySpec('a')
        self.zoomtool.next_state_key = KeySpec('s')
        plot.overlays.append(self.zoomtool)

        self.scattertool = None
        self.scatter.overlays.append(
            ScatterInspectorOverlay(self.scatter,
                                    hover_color="black",
                                    selection_color="black",
                                    selection_outline_color="red",
                                    selection_line_width=2))

        self.colorbar = colorbar = ColorBar(
            index_mapper=LinearMapper(range=plot.color_mapper.range),
            color_mapper=plot.color_mapper.range,
            plot=plot,
            orientation='v',
            resizable='v',
            width=30,
            padding=5)
        colorbar.padding_top = plot.padding_top
        colorbar.padding_bottom = plot.padding_bottom
        colorbar._axis.title = 'Weights'

        self.plotcontainer = container = HPlotContainer(use_backbuffer=True)
        container.add(plot)
        container.add(colorbar)

        super(FitGui, self).__init__(**traits)

        self.on_trait_change(self._scale_change,
                             'plot.value_scale,plot.index_scale')

        if weights is not None and len(weights) == 2:
            self.weightsChanged()  # update error bars

    def _weights0rem_changed(self, old, new):
        """
        Swaps the plot's color mapper when `weights0rem` is toggled and
        requests a redraw.
        """
        if new:
            self.plot.color_mapper = _cmapblack(self.plot.color_mapper.range)
        else:
            self.plot.color_mapper = _cmap(self.plot.color_mapper.range)
        self.plot.request_redraw()

    def _paramsChanged(self):
        """Requests a model plot update when the model parameters change."""
        self.updatemodelplot = True

    def _nmod_changed(self):
        """Requests a model plot update when the number of points changes."""
        self.updatemodelplot = True

    def _rangeChanged(self):
        """Requests a model plot update when the x range is updated."""
        self.updatemodelplot = True

    def _scale_change(self):
        """Redraws the plot when either axis scale changes."""
        self.plot.request_redraw()

    def _updatemodelplot_fired(self, new):
        """
        Regenerates the 'xmod'/'ymod' model curve (or, in 'residuals' mode,
        replaces 'ydata' with the model residuals).

        :param new:
            A true value means "update if `autoupdate` is on"; a false value
            (e.g. a button click) forces the update.

        :raises ValueError: If `ytype` is not one of the known values.
        """
        # If the plot has not been generated yet, just skip the update
        if self.plot is None:
            return

        # if False (e.g. button click), update regardless, otherwise check
        # for autoupdate
        if new and not self.autoupdate:
            return

        mod = self.tmodel.model
        if self.ytype == 'data and model':
            if mod:
                xl = self.plot.index_range.low
                xh = self.plot.index_range.high
                if self.plot.index_scale == "log":
                    xmod = np.logspace(np.log10(xl), np.log10(xh), self.nmod)
                else:
                    xmod = np.linspace(xl, xh, self.nmod)
                ymod = self.tmodel.model(xmod)

                self.plot.data.set_data('xmod', xmod)
                self.plot.data.set_data('ymod', ymod)
            else:
                self.plot.data.set_data('xmod', [])
                self.plot.data.set_data('ymod', [])
        elif self.ytype == 'residuals':
            if mod:
                self.plot.data.set_data('xmod', [])
                self.plot.data.set_data('ymod', [])
                # residuals set the ydata instead of setting the model
                res = mod.residuals(*self.data)
                self.plot.data.set_data('ydata', res)
            else:
                # residuals are meaningless without a model
                self.ytype = 'data and model'
        else:
            # the original `assert True` here could never fail
            raise ValueError('invalid Enum')

    def _fitmodel_fired(self):
        """
        Fits the current model to the data by assigning the fit keywords to
        `tmodel.fitdata`, then requests plot and statistics updates.

        Auto-updating is suspended for the duration of the fit and restored
        afterwards, even if the fit raises.
        """
        from warnings import warn

        preaup = self.autoupdate
        try:
            self.autoupdate = False
            xd, yd = self.data
            kwd = {'x': xd, 'y': yd}
            if self.weights is not None:
                w = self.weights
                if self.weights0rem:
                    if xd.shape == w.shape:
                        # drop the points with zero weight
                        m = w != 0
                        w = w[m]
                        kwd['x'] = kwd['x'][m]
                        kwd['y'] = kwd['y'][m]
                    elif np.any(w == 0):
                        warn("can't remove 0-weighted points if weights don't match data")
                kwd['weights'] = w
            self.tmodel.fitdata = kwd
        finally:
            self.autoupdate = preaup

        self.updatemodelplot = True
        self.updatestats = True

    def _newmodel_fired(self, newval):
        """
        Replaces the current model.

        :param newval:
            A string, a `FunctionModel1D` instance, or a `FunctionModel1D`
            subclass is wrapped directly in a `TraitedModel`. Any other
            value (e.g. a button click) opens the modal model selector.
        """
        from inspect import isclass

        if isinstance(newval, basestring) or isinstance(newval, FunctionModel1D) \
           or (isclass(newval) and issubclass(newval, FunctionModel1D)):
            self.tmodel = TraitedModel(newval)
        else:
            if self.modelselector.edit_traits(kind='modal').result:
                cls = self.modelselector.selectedmodelclass
                if cls is None:
                    self.tmodel = TraitedModel(None)
                elif self.modelselector.isvarargmodel:
                    self.tmodel = TraitedModel(cls(self.modelselector.modelargnum))
                    self.fitmodel = True
                else:
                    self.tmodel = TraitedModel(cls())
                    self.fitmodel = True
            else:  # cancelled
                return

    def _showerror_fired(self, evt):
        """Shows the last fitting failure, if any, in a dialog."""
        if self.tmodel.lastfitfailure:
            ex = self.tmodel.lastfitfailure
            dialog = HasTraits(s=ex.__class__.__name__ + ': ' + str(ex))
            view = View(Item('s', style='custom', show_label=False),
                        resizable=True, buttons=['OK'],
                        title='Fitting error message')
            dialog.edit_traits(view=view)

    @cached_property
    def _get_chi2(self):
        """First element of the model's `chi2Data()`, or 0 on any failure."""
        try:
            return self.tmodel.model.chi2Data()[0]
        except Exception:
            # best-effort: no model, no data, etc.
            return 0

    @cached_property
    def _get_chi2r(self):
        """Second element of the model's `chi2Data()`, or 0 on any failure."""
        try:
            return self.tmodel.model.chi2Data()[1]
        except Exception:
            # best-effort: no model, no data, etc.
            return 0

    def _get_nomodel(self):
        """True if no model is currently set."""
        return self.tmodel.model is None

    def _get_weightsvary(self):
        """True if any weight differs from the first one."""
        w = self.weights
        return np.any(w != w[0]) if len(w) > 0 else False

    def _get_plotname(self):
        """'<x title> vs <y title>', or '' if both axis titles are empty."""
        xlabel = self.plot.x_axis.title
        ylabel = self.plot.y_axis.title
        if xlabel == '' and ylabel == '':
            return ''
        else:
            return xlabel + ' vs ' + ylabel

    def _set_plotname(self, val):
        """
        Sets the axis titles.

        :param val:
            Either a string of the form 'x vs y' (or 'x-y'), or a
            2-sequence of (x title, y title).

        :raises ValueError: If `val` does not supply two titles.
        """
        if isinstance(val, basestring):
            parts = val.split('vs')
            if len(parts) == 1:
                # split the original string; the original code called
                # .split on the list returned above, which always failed
                parts = val.split('-')
            val = [v.strip() for v in parts]
        if len(val) < 2:
            raise ValueError('plotname must supply both an x and a y title')
        # the axes live on the plot (see _get_plotname), not on this object
        self.plot.x_axis.title = val[0]
        self.plot.y_axis.title = val[1]

    #selection-related
    def _scattertool_changed(self, old, new):
        """
        Tears down the tools/handlers for the `old` selection mode and
        installs those for the `new` one.
        """
        # The "no selection" value of the Enum is None ('No Selection' is
        # only the editor's display label); panning uses the left button
        # only when no selection tool needs it.
        if new is None:
            self.plot.tools[0].drag_button = 'left'
        else:
            self.plot.tools[0].drag_button = 'right'

        if old is not None and 'lasso' in old:
            if new is not None and 'lasso' in new:
                # lasso -> lasso: keep the tool, just switch the mode
                self.lassomode = new.replace('lasso', '')
                return
            else:
                #TODO:test
                self.scatter.tools[-1].on_trait_change(self._lasso_handler,
                                                       'selection_changed',
                                                       remove=True)
                del self.scatter.overlays[-1]
                del self.lassomode
        elif old == 'clickimmediate':
            self.scatter.index.on_trait_change(self._immediate_handler,
                                               'metadata_changed',
                                               remove=True)

        self.scatter.tools = []
        if new is None:
            pass
        elif 'click' in new:
            smodemap = {'clickimmediate': 'single', 'clicksingle': 'single',
                        'clicktoggle': 'toggle'}
            self.scatter.tools.append(ScatterInspector(self.scatter,
                                      selection_mode=smodemap[new]))
            if new == 'clickimmediate':
                self.clearsel = True
                self.scatter.index.on_trait_change(self._immediate_handler,
                                                   'metadata_changed')
        elif 'lasso' in new:
            lasso_selection = LassoSelection(component=self.scatter,
                                             selection_datasource=self.scatter.index)
            self.scatter.tools.append(lasso_selection)
            lasso_overlay = LassoOverlay(lasso_selection=lasso_selection,
                                         component=self.scatter)
            self.scatter.overlays.append(lasso_overlay)
            self.lassomode = new.replace('lasso', '')
            lasso_selection.on_trait_change(self._lasso_handler,
                                            'selection_changed')
            lasso_selection.on_trait_change(self._lasso_handler,
                                            'selection_completed')
            lasso_selection.on_trait_change(self._lasso_handler,
                                            'updated')
        else:
            raise TraitsError('invalid scattertool value')

    def _weightchangesel_fired(self):
        """Sets the weights of the selected points to `weightchangeto`."""
        self.weights[self.selectedi] = self.weightchangeto
        if self.unselectonaction:
            self.clearsel = True

        self._sel_alter_weights()
        self.lastselaction = 'weightchangesel'

    def _delsel_fired(self):
        """Sets the weights of the selected points to 0."""
        self.weights[self.selectedi] = 0
        if self.unselectonaction:
            self.clearsel = True

        self._sel_alter_weights()
        self.lastselaction = 'delsel'

    def _sel_alter_weights(self):
        """
        Switches to 'custom' weights (if not already there) after an
        in-place weight edit, and refreshes the weight display.
        """
        if self.weighttype != 'custom':
            self._customweights = self.weights
            self.weighttype = 'custom'
        self.weightsChanged()

    def _clearsel_fired(self, event):
        """
        Replaces the selection with `event` if it is a list, otherwise
        clears the selection.
        """
        if isinstance(event, list):
            self.scatter.index.metadata['selections'] = event
        else:
            self.scatter.index.metadata['selections'] = list()

    def _lasso_handler(self, name, new):
        """
        Handles the lasso tool's events: combines the lasso mask with the
        existing selection according to `lassomode`, and toggles the lasso
        overlay's visibility.

        :raises ValueError: If `name` is not a handled event name.
        """
        if name == 'selection_changed':
            lassomask = self.scatter.index.metadata['selection'].astype(int)
            clickmask = np.zeros_like(lassomask)
            clickmask[self.scatter.index.metadata['selections']] = 1

            if self.lassomode == 'add':
                mask = clickmask | lassomask
            elif self.lassomode == 'remove':
                mask = clickmask & ~lassomask
            elif self.lassomode == 'invert':
                mask = np.logical_xor(clickmask, lassomask)
            else:
                raise TraitsError('lassomode is in invalid state')

            self.scatter.index.metadata['selections'] = list(np.where(mask)[0])
        elif name == 'selection_completed':
            self.scatter.overlays[-1].visible = False
        elif name == 'updated':
            self.scatter.overlays[-1].visible = True
        else:
            raise ValueError('traits event name %s invalid'%name)

    def _immediate_handler(self):
        """
        In 'clickimmediate' mode, re-fires the last selection action on the
        newly selected point and then deselects it.
        """
        sel = self.selectedi
        if len(sel) > 1:
            self.clearsel = True
            raise TraitsError('selection error in immediate mode - more than 1 selection')
        elif len(sel) == 1:
            if self.lastselaction != 'None':
                setattr(self, self.lastselaction, True)
            del sel[0]

    def _savews_fired(self):
        """Stores a copy of the current weights in `_savedws`."""
        self._savedws = self.weights.copy()

    def _loadws_fired(self):
        """Restores the weights from `_savedws`."""
        self.weights = self._savedws
        # re-save so `_savedws` holds a copy rather than the live array
        self._savews_fired()

    def _get_selectedi(self):
        """The 'selections' metadata of the scatter plot's index."""
        return self.scatter.index.metadata['selections']

    @on_trait_change('data,ytype', post_init=True)
    def dataChanged(self):
        """
        Updates the application state if the fit data are altered - the GUI
        will know if you give it a new data array, but not if the data is
        changed in-place.
        """
        pd = self.plot.data
        #TODO:make set_data apply to both simultaneously?
        pd.set_data('xdata', self.data[0])
        pd.set_data('ydata', self.data[1])

        self.updatemodelplot = False

    @on_trait_change('weights', post_init=True)
    def weightsChanged(self):
        """
        Updates the application state if the weights/error bars for this
        model are changed - the GUI will automatically do this if you give
        it a new set of weights array, but not if they are changed in-place.
        """
        weights = self.weights
        if 'errorplots' in self.trait_names():
            #TODO:switch this to updating error bar data/visibility changing
            if self.errorplots is not None:
                self.plot.remove(self.errorplots[0])
                self.plot.remove(self.errorplots[1])
                # was misspelled `errorbarplots`, which left the stale
                # renderers in place to be removed again on the next call
                self.errorplots = None

            if len(weights.shape) == 2 and weights.shape[0] == 2:
                # weights are inverse errors
                xerr, yerr = 1/weights

                high = ArrayDataSource(self.scatter.index.get_data()+xerr)
                low = ArrayDataSource(self.scatter.index.get_data()-xerr)
                ebpx = ErrorBarPlot(orientation='v',
                                    value_high=high,
                                    value_low=low,
                                    index=self.scatter.value,
                                    value=self.scatter.index,
                                    index_mapper=self.scatter.value_mapper,
                                    value_mapper=self.scatter.index_mapper)
                self.plot.add(ebpx)

                high = ArrayDataSource(self.scatter.value.get_data()+yerr)
                low = ArrayDataSource(self.scatter.value.get_data()-yerr)
                ebpy = ErrorBarPlot(value_high=high,
                                    value_low=low,
                                    index=self.scatter.index,
                                    value=self.scatter.value,
                                    index_mapper=self.scatter.index_mapper,
                                    value_mapper=self.scatter.value_mapper)
                self.plot.add(ebpy)

                self.errorplots = (ebpx, ebpy)

        # collapse multi-dimensional weights to one color value per point
        while len(weights.shape) > 1:
            weights = np.sum(weights**2, axis=0)
        self.plot.data.set_data('weights', weights)
        self.plot.plots['data'][0].color_mapper.range.refresh()

        # only show the color bar when there is something to distinguish
        if self.weightsvary:
            if self.colorbar not in self.plotcontainer.components:
                self.plotcontainer.add(self.colorbar)
                self.plotcontainer.request_redraw()
        elif self.colorbar in self.plotcontainer.components:
            self.plotcontainer.remove(self.colorbar)
            self.plotcontainer.request_redraw()

    def _weighttype_changed(self, name, old, new):
        """
        Regenerates the weights for the newly selected weighting scheme,
        stashing the custom weights when leaving 'custom' mode.
        """
        if old == 'custom':
            self._customweights = self.weights

        if new == 'custom':
            self.weights = self._customweights
        elif new == 'equal':
            self.weights = np.ones_like(self.data[0])
        elif new == 'lin bins':
            self.weights = binned_weights(self.data[0], 10, False)
        elif new == 'log bins':
            self.weights = binned_weights(self.data[0], 10, True)
        else:
            # NOTE(review): other methods in this class raise `TraitsError`
            # while this one raises `TraitError` - confirm which name(s) the
            # module actually imports.
            raise TraitError('Invalid Enum value on weighttype')

    def getModelInitStr(self):
        """
        Generates a python code string that can be used to generate a model
        with parameters matching the model in this :class:`FitGui`.

        :returns: initializer string
        """
        mod = self.tmodel.model
        if mod is None:
            return 'None'
        else:
            parstrs = []
            for p, v in mod.pardict.iteritems():
                parstrs.append(p + '=' + str(v))
            if mod.__class__._pars is None:
                # varargs need to have the first argument give the right
                # number
                varcount = len(mod.params) - len(mod.__class__._statargs)
                parstrs.insert(0, str(varcount))
            return '%s(%s)'%(mod.__class__.__name__, ','.join(parstrs))

    def getModelObject(self):
        """
        Gets the underlying object representing the model for this fit.

        :returns: The :class:`pymodelfit.core.FunctionModel1D` object.
        """
        return self.tmodel.model