class InternetExplorerDemo(HasTraits):
    """Tabbed web-browser demo.

    Every time a location is committed in the 'url' field, a new WebPage
    is appended to 'pages' and shown as an extra notebook tab.
    """

    # Location typed by the user; committing it opens a new page:
    url = Str('http://')

    # All pages opened so far (one notebook tab per page):
    pages = List(WebPage)

    view = View(
        VGroup(
            Item('url',
                 label='Location',
                 # Only commit the text when Enter is pressed, not on
                 # every keystroke:
                 editor=TextEditor(auto_set=False, enter_set=True)),
        ),
        Item('pages',
             show_label=False,
             style='custom',
             editor=ListEditor(use_notebook=True,
                               deletable=True,
                               dock_style='tab',
                               export='DockWindowShell',
                               # Tab caption comes from each page's title:
                               page_name='.title')),
    )

    #-- Event handlers ---------------------------------------------------------

    def _url_changed(self, url):
        """Open the newly entered *url* (whitespace-trimmed) in a new tab."""
        new_page = WebPage(url=url.strip())
        self.pages.append(new_page)
class FileTreeDemo(HasTraits):
    """Browse the file system as a read-only tree rooted at 'root_path'."""

    # Directory used as the tree root; `entries` is passed to the path
    # editor as the size of its history list:
    root_path = Directory(entries=10)

    # File node wrapping root_path (recomputed when root_path changes):
    root = Property

    view = View(
        VGroup(
            Item('root_path'),
            Item('root',
                 editor=TreeEditor(editable=False, auto_open=1)),
            show_labels=False,
        ),
        width=0.33,
        height=0.50,
        resizable=True,
    )

    #-- Traits Default Value Methods -------------------------------------------

    def _root_path_default(self):
        """Start browsing at the process's current working directory."""
        start_dir = getcwd()
        return start_dir

    #-- Property Implementations -----------------------------------------------

    @property_depends_on('root_path')
    def _get_root(self):
        """Return a File node for the current root path."""
        node = File(path=self.root_path)
        return node
class Employee(HasTraits):
    """Employee record displayed beneath an 'info' image."""

    name = Str
    dept = Str
    email = Str

    # The upper group shows the 'info' image resource through an
    # ImageEditor bound to 'name'; the lower group holds the editable
    # fields. `search_path` is a module-level name defined elsewhere.
    view = View(
        VGroup(
            VGroup(
                Item('name',
                     show_label=False,
                     editor=ImageEditor(
                         image=ImageResource('info',
                                             search_path=search_path))),
            ),
            VGroup(
                Item('name'),
                Item('dept'),
                Item('email'),
            ),
        )
    )
class Converter(HasStrictTraits):
    """Convert an amount between two units of the `Units` trait type."""

    #-- Trait definitions ------------------------------------------------------

    input_amount = CFloat(12.0, desc="the input quantity")

    input_units = Units('inches', desc="the input quantity's units")

    # Read-only result, recomputed whenever the input or either unit changes:
    output_amount = Property(
        depends_on=['input_amount', 'input_units', 'output_units'],
        desc="the output quantity")

    output_units = Units('feet', desc="the output quantity's units")

    #-- User interface views ---------------------------------------------------

    traits_view = View(
        VGroup(
            HGroup(
                Item('input_amount', springy=True),
                Item('input_units', show_label=False),
                label='Input',
                show_border=True,
            ),
            HGroup(
                Item('output_amount', style='readonly', springy=True),
                Item('output_units', show_label=False),
                label='Output',
                show_border=True,
            ),
            help=ViewHelp,
        ),
        title='Units Converter',
        buttons=['Undo', 'OK', 'Help'],
    )

    #-- Property implementations -----------------------------------------------

    def _get_output_amount(self):
        """Scale by the input unit's mapped value, then divide by the output
        unit's mapped value.

        NOTE(review): `Units` is defined elsewhere; the trailing-underscore
        attributes are its mapped values, presumably per-unit scale factors.
        """
        in_factor = self.input_units_
        out_factor = self.output_units_
        return (self.input_amount * in_factor) / out_factor
class SpringDemo(HasTraits):
    """Show how springs justify buttons inside horizontal groups."""

    ignore = Button('Ignore')

    # `button` and `spring` are module-level names defined elsewhere in the
    # file -- presumably an Item for 'ignore' and a TraitsUI spring; TODO
    # confirm.
    view = View(
        VGroup(
            HGroup(button, spring, button,
                   show_border=True,
                   label='Left and right justified'),
            HGroup(button, button, spring,
                   button, button, spring,
                   button, button,
                   show_border=True,
                   label='Left, center and right justified'),
            HGroup(spring, button, button,
                   show_border=True,
                   label='Right justified'),
            HGroup(button, button,
                   show_border=True,
                   label='Left justified (no springs)'),
        ),
        title='Spring Demo',
        buttons=['OK'],
    )
class LineCountInfo(MFileDialogModel):
    """ Defines a file dialog extension that displays the number of text
        lines in the currently selected file.
    """

    # The number of text lines in the currently selected file:
    lines = Property(depends_on='file_name')

    #-- Traits View Definitions ------------------------------------------------

    view = View(
        VGroup(
            Item('lines', style='readonly'),
            label='Line Count Info',
            show_border=True,
        )
    )

    #-- Property Implementations -----------------------------------------------

    @cached_property
    def _get_lines(self):
        """Return a short line-count summary for the selected file.

        Returns:
            '' if the file cannot be sized or read (missing, unreadable,
            invalid name); 'File too big...' for files over 10,000,000
            bytes; 'File contains binary data...' if a NUL or 0xFF byte is
            present; otherwise '<n> lines' with n formatted by commatize().
        """
        max_bytes = 10000000
        try:
            if getsize(self.file_name) > max_bytes:
                return 'File too big...'
            # `with` closes the handle even if read() raises (the original
            # leaked it in that case).
            with open(self.file_name, 'rb') as fh:
                data = fh.read()
        except (EnvironmentError, TypeError, ValueError):
            # Missing/unreadable file or an unusable name: show nothing.
            # (Narrowed from a bare `except:` so Ctrl-C is not swallowed.)
            return ''

        # Bytes literals keep the check correct on both Python 2 and 3.
        if b'\x00' in data or b'\xFF' in data:
            return 'File contains binary data...'

        return '%s lines' % commatize(len(data.splitlines()))
class EventTableManager(HasTraits):
    """ Manage the EventTables used for the display.
    Also controls where new marks are saved to.
    """

    # One list entry per registered event-table file (shown in the table):
    evt_filenames = List(EventTableListEntry)
    # Loaded tables, appended in the same order as evt_filenames:
    evts = List(Instance(eegpy.EventTable))

    view = View(
        VGroup(
            Group(
                Item('evt_filenames',
                     editor=tabular_editor,
                     show_label=False),
                label="EventTables",
            ),
        ),
    )

    def append(self, et_fn):
        """Register and load the event-table file *et_fn*.

        Returns:
            False if *et_fn* is already registered (nothing is changed),
            True once it has been added and loaded. (Previously the success
            path returned None, which callers could not tell apart from
            False in a truthiness test.)
        """
        if any(entry.fn == et_fn for entry in self.evt_filenames):
            return False
        self.evt_filenames.append(EventTableListEntry(fn=et_fn))
        self.evts.append(eegpy.EventTable(str(et_fn)))
        return True

    def get_marks(self, start, stop):
        """Return (key, t) pairs for every mark with start <= t <= stop.

        Iterates every loaded table, every key of that table, and every
        time stored under the key, in that order.
        """
        marks = []
        for table in self.evts:
            for key in table.keys():
                for t in table[key]:
                    if start <= t <= stop:
                        marks.append((key, t))
        return marks
class MyViewController(Controller):
    """ Define a combined controller/view class that validates that
        MyModel.name is consistent with the 'allow_empty_string' flag.
    """

    # When False, the model.name trait is not allowed to be empty:
    allow_empty_string = Bool

    # Last attempted value of model.name to be set by user:
    last_name = Str

    # Define the view associated with this controller:
    view = View(
        VGroup(
            HGroup(Item('name', springy=True),
                   '10',  # spacer between the two items
                   Item('controller.allow_empty_string',
                        label='Allow Empty')),
            # Add an empty vertical group so the above items don't end up
            # centered vertically:
            VGroup(),
        ),
        resizable=True,
    )

    #-- Handler Interface ------------------------------------------------------

    def name_setattr(self, info, object, name, value):
        """ Validate the request to change the named trait on object to the
            specified value. Validation errors raise TraitError.
        """
        # Remember every attempt, even rejected ones:
        self.last_name = value
        if (not self.allow_empty_string) and (value.strip() == ''):
            raise TraitError('Empty string not allowed.')

        return super(MyViewController, self).setattr(info, object, name, value)

    #-- Event handlers ---------------------------------------------------------

    def controller_allow_empty_string_changed(self, info):
        """ 'allow_empty_string' has changed, check the name trait to ensure
            that it is consistent with the current setting.
        """
        # Use the same blank test as name_setattr: previously a
        # whitespace-only name slipped through when empties were disallowed.
        if (not self.allow_empty_string) and (self.model.name.strip() == ''):
            self.model.name = '?'
        else:
            self.model.name = self.last_name
class FloodFillDemo(HasTraits):
    """Interactive cv.floodFill demo on lena.jpg.

    Clicking the image (via PointPicker) sets 'point'; changing the point,
    the tolerances or the algorithm option re-runs the flood fill.
    """

    # Lower / upper colour-difference tolerances (one row, 4 channels).
    # `float` replaces the removed `np.float` alias (same float64 dtype).
    lo_diff = Array(float, (1, 4))
    hi_diff = Array(float, (1, 4))
    plot = Instance(Plot)
    # Seed point of the fill:
    point = Tuple((0, 0))
    # Flood-fill flags; `Options` (defined elsewhere) maps the label
    # ("relative to neighbour, 4-connected") to a cv flag value:
    option = Trait(u"以邻点为标准-4联通", Options)

    # UI labels: lower range / upper range / algorithm flags; title:
    # "FloodFill Demo control panel".
    view = View(
        VGroup(
            VGroup(
                Item("lo_diff", label=u"负方向范围"),
                Item("hi_diff", label=u"正方向范围"),
                Item("option", label=u"算法标志"),
            ),
            Item("plot", editor=ComponentEditor(), show_label=False),
        ),
        title=u"FloodFill Demo控制面板",
        width=500,
        height=450,
        resizable=True,
    )

    def __init__(self, *args, **kwargs):
        """Load the image, build the plot and hook up redraws.

        Keyword arguments are now forwarded to HasTraits (the original
        never called HasTraits.__init__, so they were silently dropped).
        Positional arguments are still accepted and ignored, as before.
        """
        # Default tolerances first, so that lo_diff/hi_diff passed as
        # keywords take precedence over them.
        self.lo_diff.fill(5)
        self.hi_diff.fill(5)
        super(FloodFillDemo, self).__init__(**kwargs)

        self.img = cv.imread("lena.jpg")
        # Channel order is reversed for display (img[:, :, ::-1]).
        self.data = ArrayPlotData(img=self.img[:, :, ::-1])
        w = self.img.size().width
        h = self.img.size().height
        self.plot = Plot(self.data, padding=10, aspect_ratio=float(w) / h)
        self.plot.x_axis.visible = False
        self.plot.y_axis.visible = False
        self.imgplot = self.plot.img_plot("img", origin="top left")[0]
        self.imgplot.interpolation = "nearest"
        self.imgplot.overlays.append(
            PointPicker(application=self, component=self.imgplot))

        self.on_trait_change(self.redraw, "point,lo_diff,hi_diff,option")

    def redraw(self):
        """Flood-fill a copy of the image from 'point' and display it."""
        img = self.img.clone()
        cv.floodFill(img, cv.Point(*self.point), cv.Scalar(255, 0, 0, 255),
                     loDiff=cv.asScalar(self.lo_diff[0]),
                     upDiff=cv.asScalar(self.hi_diff[0]),
                     flags=self.option_)
        self.data["img"] = img[:, :, ::-1]
class System(HasTraits):
    """Mass/velocity system whose status flags excessive kinetic energy."""

    mass = Range(0.0, 100.0)
    velocity = Range(0.0, 100.0)

    # m * v * v / 2, recomputed from mass and velocity:
    kinetic_energy = Property(Float)

    # True when kinetic_energy exceeds 50000; synced to the 'invalid' state
    # of the mass, velocity and status editors:
    error = Property(
        Bool, sync_to_view='mass.invalid, velocity.invalid, status.invalid')

    # Human-readable message ('' when there is no error):
    status = Property(Str)

    view = View(
        VGroup(
            VGroup(
                Item('mass'),
                Item('velocity'),
                Item('kinetic_energy', style='readonly', format_str='%.0f'),
                label='System',
                show_border=True,
            ),
            VGroup(
                Item('status', style='readonly', show_label=False),
                label='Status',
                show_border=True,
            ),
        )
    )

    @property_depends_on('mass, velocity')
    def _get_kinetic_energy(self):
        m, v = self.mass, self.velocity
        return m * v * v / 2.0

    @property_depends_on('kinetic_energy')
    def _get_error(self):
        return self.kinetic_energy > 50000.0

    @property_depends_on('error')
    def _get_status(self):
        if not self.error:
            return ''
        return 'The kinetic energy of the system is too high.'
def default_traits_view(self):
    '''
    Build the modal response-function view: a tab-docked split holding the
    function parameters, plot parameters, the model comment and the plot.
    '''
    function_tab = VGroup(
        Item('@resp_func', show_label=False),
        label='Function Parameters',
        id='stats.spirrid_bak.rf_model_view.rf_params',
        scrollable=True)

    plot_params_tab = VGroup(
        Item('max_x', label='max x value'),
        Item('n_points', label='No of plot points'),
        label='Plot Parameters',
        id='stats.spirrid_bak.rf_model_view.plot_params')

    comment_tab = VGroup(
        Item('model.comment', show_label=False, style='readonly'),
        label='Comment',
        id='stats.spirrid_bak.rf_model_view.comment',
        scrollable=True)

    plot_tab = VGroup(
        HGroup(Item('show', show_label=False),
               Item('clear', show_label=False)),
        Item('figure', editor=MPLFigureEditor(),
             resizable=True, show_label=False),
        label='Plot',
        id='stats.spirrid_bak.rf_model_view.plot')

    return View(
        HSplit(function_tab, plot_params_tab, comment_tab, plot_tab,
               dock='tab',
               id='stats.spirrid_bak.rf_model_view.split'),
        kind='modal',
        resizable=True,
        dock='tab',
        buttons=[OKButton],
        id='stats.spirrid_bak.rf_model_view')
class Model(HasTraits):
    """Layout demo: tabs, splitters and docking with two simple traits."""

    a = Code("print 'hello'")
    b = Button("click me")

    traits_view = View(
        HSplit(
            # Left pane: three tabs all editing 'a', then the button.
            VGroup(
                Tabbed(Item('a'), Item('a'), Item('a')),
                Item('b'),
            ),
            # Right pane: vertically split button column and bordered row.
            VSplit(
                VGroup('b', 'b', 'b'),
                HGroup('a', show_border=True, label="traits is great"),
            ),
            dock="horizontal",
        ),
        resizable=True,
        id="my.test.program.id",
    )
class Shape(HasTraits):
    """Rectangle or circle whose visible fields depend on 'shape_type'."""

    shape_type = Enum("rectangle", "circle")
    # When False, the geometry fields are shown but disabled:
    editable = Bool

    x = Int
    y = Int
    w = Int
    h = Int
    r = Int

    view = View(
        VGroup(
            HGroup(Item("shape_type"), Item("editable")),
            # Rectangle geometry (shown only for rectangles):
            VGroup(Item("x"), Item("y"), Item("w"), Item("h"),
                   visible_when="shape_type=='rectangle'",
                   enabled_when="editable"),
            # Circle geometry (shown only for circles):
            VGroup(Item("x"), Item("y"), Item("r"),
                   visible_when="shape_type=='circle'",
                   enabled_when="editable"),
        ),
        resizable=True,
    )
class Circle(Shape):
    """Circle given by a centre Point and an integer radius."""

    center = Instance(Point, ())
    r = Int

    view = View(
        VGroup(
            Item("center", style="custom"),
            Item("r"),
        )
    )

    @on_trait_change("r")
    def set_info(self):
        """Refresh 'info' with the area whenever the radius changes."""
        area = pi * self.r ** 2
        self.info = "area:%f" % area
class MyDemo(HasTraits):
    """Render a TVTK parametric surface chosen from `source_types`.

    Any trait change on the source or on the selected parametric function
    re-renders the scene.
    """

    scene = Instance(SceneModel, ())

    source = Instance(tvtk.ParametricFunctionSource, ())

    # Class name of the parametric function to display:
    func_name = Enum([c.__name__ for c in source_types])

    # The parametric-function object for func_name (looked up in `sources`):
    func = Property(depends_on="func_name")

    traits_view = View(
        HSplit(
            VGroup(
                Item("func_name", show_label=False),
                Tabbed(
                    Item("func", style="custom",
                         editor=InstanceEditor(), show_label=False),
                    Item("source", style="custom", show_label=False),
                ),
            ),
            Item("scene", style="custom", show_label=False,
                 editor=SceneEditor()),
        ),
        resizable=True, width=700, height=600,
    )

    def __init__(self, *args, **kwds):
        super(MyDemo, self).__init__(*args, **kwds)
        self._make_pipeline()

    def _get_func(self):
        return sources[self.func_name]

    def _make_pipeline(self):
        """Wire function -> source -> mapper -> actor into the scene."""
        func = self.func
        src = self.source
        for watched in (func, src):
            watched.on_trait_change(self.on_change, "anytrait")
        src.parametric_function = func
        # Named `mapper` rather than `map` to avoid shadowing the builtin.
        mapper = tvtk.PolyDataMapper(input_connection=src.output_port)
        self.scene.add_actor(tvtk.Actor(mapper=mapper))
        self.src = src

    def _func_changed(self, old_func, this_func):
        """Move the re-render listener to the new function and display it."""
        if old_func is not None:
            old_func.on_trait_change(self.on_change, "anytrait", remove=True)
        this_func.on_trait_change(self.on_change, "anytrait")
        self.src.parametric_function = this_func
        self.scene.render()

    def on_change(self):
        self.scene.render()
class DoublePendulumGUI(HasTraits):
    """Animated double-pendulum demo driven by a pyface Timer.

    Mass/length sliders take effect the next time a trajectory segment is
    computed (i.e. after the current one has been played back).
    """

    pendulum = Instance(DoublePendulum)
    m1 = Range(1.0, 10.0, 2.0)
    m2 = Range(1.0, 10.0, 2.0)
    l1 = Range(1.0, 10.0, 2.0)
    l2 = Range(1.0, 10.0, 2.0)
    # Trajectory returned by double_pendulum_odeint; each element is a
    # sequence indexed by frame number:
    positions = Tuple
    # Next frame of `positions` to display:
    index = Int(0)
    timer = Instance(Timer)
    graph = Instance(DoublePendulumComponent)
    animation = Bool(True)

    # Window title: "double pendulum demo".
    view = View(
        HGroup(
            VGroup(Item("m1"), Item("m2"), Item("l1"), Item("l2")),
            Item("graph", editor=ComponentEditor(), show_label=False),
        ),
        width=600, height=400, title="双摆演示", resizable=True,
    )

    def __init__(self, **traits):
        """Create the pendulum, the drawing component and a 10 ms timer.

        Keyword traits (e.g. m1=3.0) are now forwarded to HasTraits; the
        original never called HasTraits.__init__.
        """
        super(DoublePendulumGUI, self).__init__(**traits)
        self.pendulum = DoublePendulum(self.m1, self.m2, self.l1, self.l2)
        self.pendulum.init_status[:] = 1.0, 2.0, 0, 0
        self.graph = DoublePendulumComponent()
        self.graph.gui = self
        self.timer = Timer(10, self.on_timer)

    def on_timer(self, *args):
        """Show the next frame, computing a new trajectory segment when the
        previous one has been fully played."""
        if not self.positions or self.index == len(self.positions[0]):
            p = self.pendulum
            # Pick up the current slider values for the next segment.
            p.m1, p.m2 = self.m1, self.m2
            p.l1, p.l2 = self.l1, self.l2
            if self.animation:
                self.positions = double_pendulum_odeint(p, 0, 0.5, 0.02)
            else:
                self.positions = double_pendulum_odeint(
                    p, 0, 0.00001, 0.00001)
            self.index = 0
        self.graph.p = tuple(series[self.index] for series in self.positions)
        self.index += 1
        self.graph.request_redraw()
class Factors(HasTraits):
    """Table of the numbers 1..max_n together with their factors."""

    # The maximum number to include in the table:
    max_n = Range(1, 1000, 20, mode='slider')

    # One Factor object per number 1..max_n:
    factors = Property(List)

    view = View(
        VGroup(
            VGroup(
                Item('max_n'),
                show_labels=False,
                show_border=True,
                label='Maximum Number',
            ),
            VGroup(
                Item('factors', show_label=False,
                     editor=factors_table_editor),
            ),
        ),
        title='List of numbers and their factors',
        width=0.2,
        height=0.4,
        resizable=True,
    )

    @property_depends_on('max_n')
    def _get_factors(self):
        return [Factor(n=number) for number in range(1, self.max_n + 1)]
def default_traits_view(self):
    """Build the Mandelbrot viewer: colour-map controls above the plot."""
    # Labels: "colour map", "reverse colours", "position".
    controls = HGroup(
        # Choices come from the object's 'color_maps' trait:
        Item("current_map", label=u"颜色映射",
             editor=EnumEditor(name="object.color_maps")),
        Item("reverse_map", label=u"反转颜色"),
        Item("position", label=u"位置", style="readonly"),
    )
    canvas = Item("plot", show_label=False, editor=ComponentEditor())
    # Title: "Mandelbrot viewer".
    return View(
        VGroup(controls, canvas),
        resizable=True,
        width=550,
        height=300,
        title=u"Mandelbrot观察器",
    )
class CSVGrapher(HasTraits):
    """
    Main window: a list of plots, a data source, a CSV file chooser and an
    "add plot" button.
    """

    graph_list = List(Instance(Graph))        # open plots (one tab each)
    data_source = Instance(DataSource)        # data loaded from the CSV file
    csv_file_name = File(filter=[u"*.csv"])   # CSV file chooser
    add_graph_button = Button(u"添加绘图")     # "add plot" button

    view = View(
        # The window is split into a top and a bottom part.
        VGroup(
            # Top: controls laid out horizontally.
            HGroup(
                # File chooser (label: "choose CSV file").
                Item("csv_file_name", label=u"选择CSV文件", width=400),
                # "Add plot" button.
                Item("add_graph_button", show_label=False),
            ),
            # Bottom: the plot list, shown with a ListEditor.
            Item("graph_list", style="custom", show_label=False,
                 editor=ListEditor(
                     use_notebook=True,   # show as notebook tabs
                     deletable=True,      # tabs can be closed
                     dock_style="tab",    # tab dock style
                     page_name=".name")), # tab caption = Graph.name
        ),
        resizable=True,
        height=0.8,
        width=0.8,
        title=u"CSV数据绘图器",  # "CSV data plotter"
    )

    def _csv_file_name_changed(self):
        """A new file was chosen: load it into a fresh DataSource and close
        all existing plots."""
        self.data_source = DataSource()
        self.data_source.load_csv(self.csv_file_name)
        del self.graph_list[:]

    def _add_graph_button_changed(self):
        """Add a new plot tab for the current data source, if one is loaded."""
        # Identity test rather than `!= None`, which may dispatch to a
        # user-defined __ne__.
        if self.data_source is not None:
            self.graph_list.append(Graph(data_source=self.data_source))
def default_traits_view(self):
    """Build the DLS view: figures on the left, a one-column checklist of
    the possible usable data sets on the right."""
    choices = [str(entry) for entry in self.possible_usable_data]
    figures = VGroup(
        'renormalized',
        Item('data_fig', style='custom', show_label=False),
        'cr_fig',
        'corr_fig',
    )
    selector = Item('usable_data', style='custom', show_label=False,
                    editor=CheckListEditor(values=choices, cols=1))
    return View(HGroup(figures, selector),
                height=800, width=800, handler=DLS_DataHandler)
class Doc(HasTraits): filename = File TDocStd = Instance(TDocStd.TDocStd_Document) root_label = Instance(Label) traits_view = View(VGroup(Item("filename")), Item("root_label", editor=tree_ed, show_label=False)) def _TDocStd_changed(self, new_doc): root_label = new_doc.Main().Root() label = Label(TDF_Label=root_label) self.root_label = label print "root label entry", label.entry h_u = TNaming.Handle_TNaming_UsedShapes() gid = h_u.GetObject().getid() if root_label.FindAttribute(gid, h_u): print "got used shapes"
class TiffFileInfo(MFileDialogModel):
    """File-dialog extension that describes the selected TIFF/LSM file."""

    # Text shown in the dialog for the current selection:
    description = Property(depends_on='file_name')

    # True only once the selection has been opened as a TIFF file:
    is_ok = Bool(False)

    traits_view = View(
        VGroup(
            Item('description', style='readonly', show_label=False,
                 resizable=True),
        ),
        resizable=True,
    )

    @cached_property
    def _get_description(self):
        """Describe the selected path.

        Returns:
            - for a directory: the sorted names of its *.tif / *.lsm files;
            - for another non-file path: 'not a file';
            - for a missing path: 'file does not exists';
            - for configuration.txt: its contents, undecodable bytes ignored;
            - otherwise TIFFfile.get_info(), or an error message if the file
              is not a TIFF or its info cannot be read.

        Side effect: sets `is_ok` to True only when TIFFfile opened the path.
        """
        self.is_ok = False
        if not os.path.isfile(self.file_name):
            if not os.path.exists(self.file_name):
                return 'file does not exists'
            if os.path.isdir(self.file_name):
                files = []
                for ext in ['tif', 'lsm']:
                    files += glob.glob(self.file_name + '/*.' + ext)
                names = sorted(os.path.basename(f) for f in files)
                return 'Directory contains:\n%s' % ('\n'.join(names))
            return 'not a file'

        if os.path.basename(self.file_name) == 'configuration.txt':
            # `with` closes the file (the original left it to the GC).
            with open(self.file_name) as fh:
                return unicode(fh.read(), errors='ignore')

        try:
            tiff = TIFFfile(self.file_name, verbose=True)
        except ValueError as msg:
            return 'not a TIFF file\n%s' % (msg)
        self.is_ok = True
        try:
            r = tiff.get_info()
        except Exception as msg:
            r = 'failed to get TIFF info: %s' % (msg)
        # Previously `r` was computed but never returned, so the property
        # evaluated to None for every readable TIFF.
        return r
class HelpDialog(HasTraits):
    """Help window: choose a topic and read its description."""

    # Topic selector; the choices are the keys of the module-level `desc`:
    sections = Enum('About', desc.keys())
    # Text of the selected topic:
    stagedescription = Str(desc['About'])

    view = View(
        Item(name='sections', show_label=False),
        VGroup(
            Item(name='stagedescription', style='readonly',
                 show_label=False),
            show_border=True,
        ),
        title='Connectome Mapper Help',
        buttons=['OK'],
        resizable=True,
        width=0.4,
        height=0.6,
    )

    def _sections_changed(self, value):
        """Show the description of the newly selected topic."""
        text = desc[value]
        self.stagedescription = text
class Triangle(Shape):
    """Triangle defined by three Point vertices."""

    a = Instance(Point, ())
    b = Instance(Point, ())
    c = Instance(Point, ())

    view = View(
        VGroup(
            Item("a", style="custom"),
            Item("b", style="custom"),
            Item("c", style="custom"),
        )
    )

    @on_trait_change("a.[x,y],b.[x,y],c.[x,y]")
    def set_info(self):
        """Refresh 'info' with the edge lengths |ab|, |cb|, |ac|."""
        def edge(p, q):
            # Euclidean distance between two points.
            return ((p.x - q.x) ** 2 + (p.y - q.y) ** 2) ** 0.5

        a, b, c = self.a, self.b, self.c
        lengths = (edge(a, b), edge(c, b), edge(a, c))
        self.info = "edge length: %f, %f, %f" % lengths
class AnimatedGIFDemo(HasTraits):
    """Play one of the GIF files listed in the module-level `files`."""

    # The animated GIF file to display:
    gif_file = File(files[0])

    # Whether the animation is currently running:
    playing = Bool(True)

    view = View(
        VGroup(
            HGroup(
                Item('gif_file',
                     editor=AnimatedGIFEditor(playing='playing'),
                     show_label=False),
                Item('playing'),
            ),
            '_',  # separator line
            Item('gif_file',
                 label='GIF File',
                 editor=EnumEditor(values=files)),
        ),
        title='Animated GIF Demo',
        buttons=['OK'],
    )
class OutputSelector(HasTraits):
    """Pick one dataset from `outputs` and show a pannable "Density"
    projection of it in `main_plot`.
    """

    # Candidate outputs; an 'output' Enum over these is added in __init__.
    outputs = List
    main_plot = Instance(VariableMeshPannerView)
    # Image panner for the current dataset; cached and recomputed when `ds`
    # changes.
    main_panner = Property(depends_on="ds")
    weight_field = Str("Density")
    # Dataset loaded by _output_changed:
    ds = Any
    # Projection of `ds` along `axis`, set by _output_changed:
    source = Any
    axis = Int(0)

    traits_view = View(VGroup(
        Item('output'),
        Item('main_plot'),
    ))

    def __init__(self, **kwargs):
        super(OutputSelector, self).__init__(**kwargs)
        # 'output' can only be declared once the list of outputs is known.
        self.add_trait("output", Enum(*self.outputs))
        # Selecting the last output triggers _output_changed, which loads
        # the dataset and wires the panner into main_plot.
        self.output = self.outputs[-1]
        # Bare attribute access: forces main_plot's default
        # (_main_plot_default) to be created if it does not exist yet.
        # NOTE(review): _output_changed already touches main_plot, so this
        # looks like a safeguard; confirm before removing.
        self.main_plot

    def _output_default(self):
        return self.outputs[0]

    def _output_changed(self, old, new):
        # We get a string here
        import yt.mods
        # Order matters: ds and source must be set before main_panner (which
        # depends on both) and main_plot (whose default uses main_panner)
        # are accessed below.
        self.ds = yt.mods.load(new, dataset_type="enzo_packed_3d")
        self.source = yt.mods.projload(self.ds, self.axis, "Density")
        self.main_panner.field = self.main_plot.vm_plot.field
        # Point every component of the plot at the new panner.
        self.main_plot.panner = self.main_plot.vm_plot.panner = \
            self.main_plot.vm_plot.helper.panner = self.main_panner
        self.main_plot.vm_plot.field = self.main_panner.field

    def _main_plot_default(self):
        vmpv = VariableMeshPannerView(panner=self.main_panner)
        vmpv.vm_plot.helper.run_callbacks = True
        return vmpv

    @cached_property
    def _get_main_panner(self):
        # 512x512 image panner over the current projection.
        return self.ds.image_panner(self.source, (512, 512), "Density")
class ShapeSelector(HasTraits):
    """Pick a Shape subclass by name and edit an instance of it."""

    # Names of Shape's direct subclasses, captured at class-creation time:
    select = Enum(*[cls.__name__ for cls in Shape.__subclasses__()])
    shape = Instance(Shape)

    view = View(
        VGroup(
            Item("select"),
            Item("shape", style="custom"),
            Item("object.shape.info", style="custom"),
            show_labels=False,
        ),
        width=350,
        height=300,
        resizable=True,
    )

    def __init__(self, **traits):
        super(ShapeSelector, self).__init__(**traits)
        # Build the initial shape for the default selection.
        self._select_changed()

    def _select_changed(self):
        """Replace 'shape' with a fresh instance of the selected subclass."""
        matches = [sub for sub in Shape.__subclasses__()
                   if sub.__name__ == self.select]
        chosen = matches[0]
        self.shape = chosen()
class DemoApp(HasTraits):
    """Mayavi scene embedded in a TraitsUI window, plus a "plot" button."""

    plotbutton = Button("绘图")  # label: "plot"

    # The Mayavi scene:
    scene = Instance(MlabSceneModel, ())

    view = View(
        VGroup(
            # Mayavi editor for the scene:
            Item(name='scene',
                 editor=SceneEditor(scene_class=MayaviScene),
                 resizable=True,
                 height=250,
                 width=400),
            'plotbutton',
            show_labels=False,
        ),
        title="在TraitsUI中嵌入Mayavi",  # "Embedding Mayavi in TraitsUI"
    )

    def _plotbutton_fired(self):
        self.plot()

    def plot(self):
        """Draw mlab's built-in test mesh."""
        mlab.test_mesh()
class ListEditorNotebookSelectionDemo(HasStrictTraits):
    """Keep a spinner index and the selected notebook tab in sync."""

    #-- Trait Definitions ------------------------------------------------------

    # List of people, one notebook tab each:
    people = List(Person)

    # The person whose tab is currently selected:
    selected = Instance(Person)

    # Position of `selected` within `people`:
    index = Range(0, 7, mode='spinner')

    #-- Traits View Definitions ------------------------------------------------

    traits_view = View(
        Item('index'),
        '_',
        VGroup(
            Item('people@',
                 id='notebook',
                 show_label=False,
                 editor=ListEditor(use_notebook=True,
                                   deletable=False,
                                   selected='selected',
                                   export='DockWindowShell',
                                   page_name='.name')),
        ),
        id='enthought.traits.ui.demo.Traits UI Demo.Advanced.'
           'List_editor_notebook_selection_demo',
        dock='horizontal',
    )

    #-- Trait Event Handlers ---------------------------------------------------

    def _selected_changed(self, selected):
        """A tab was selected: move the spinner to its position."""
        position = self.people.index(selected)
        self.index = position

    def _index_changed(self, index):
        """The spinner moved: select the corresponding tab."""
        person = self.people[index]
        self.selected = person
class MultiFitGui(HasTraits):
    """
    Notebook of FitGui panels, one per (parametric axis, other axis) pair,
    with an optional 3D view of the data and the chained model curves.

    data should be c x N where c is the number of data columns/axes and N is
    the number of points
    """
    doplot3d = Bool(False)
    show3d = Button('Show 3D Plot')
    replot3d = Button('Replot 3D')
    scalefactor3d = Float(0)
    do3dscale = Bool(False)
    nmodel3d = Int(1024)
    usecolor3d = Bool(False)
    color3d = Color((0, 0, 0))
    scene3d = Instance(MlabSceneModel, ())
    plot3daxes = Tuple(('x', 'y', 'z'))
    data = Array(shape=(None, None))
    weights = Array(shape=(None,))
    # (x-axis index, y-axis index) pair plotted by each FitGui in fgs:
    curveaxes = List(Tuple(Int, Int))
    axisnames = Dict(Int, Str)
    # name -> axis index (inverse of axisnames):
    invaxisnames = Property(Dict, depends_on='axisnames')

    fgs = List(Instance(FitGui))

    traits_view = View(
        VGroup(
            Item('fgs',
                 editor=ListEditor(use_notebook=True, page_name='.plotname'),
                 style='custom', show_label=False),
            Item('show3d', show_label=False)),
        resizable=True, height=900, buttons=['OK', 'Cancel'],
        title='Multiple Model Data Fitters')

    plot3d_view = View(
        VGroup(
            Item('scene3d', editor=SceneEditor(scene_class=MayaviScene),
                 show_label=False, resizable=True),
            Item('plot3daxes',
                 editor=TupleEditor(cols=3, labels=['x', 'y', 'z']),
                 label='Axes'),
            HGroup(Item('do3dscale', label='Scale by weight?'),
                   Item('scalefactor3d', label='Point scale'),
                   Item('nmodel3d', label='Nmodel')),
            HGroup(Item('usecolor3d', label='Use color?'),
                   Item('color3d', label='Relation Color',
                        enabled_when='usecolor3d')),
            Item('replot3d', show_label=False),
            springy=True),
        resizable=True, height=800, width=800,
        title='Multiple Model3D Plot')

    def __init__(self, data, names=None, models=None, weights=None,
                 dofits=True, **traits):
        """
        :param data: The data arrays
        :type data: sequence of c equal-length arrays (length N)
        :param names: Names, or a single comma-separated string of names
        :type names: sequence of strings, length c, or a string
        :param models:
            The models to fit for each pair either as strings or
            :class:`astropysics.models.ParametricModel` objects.
        :type models: sequence of models, length c-1
        :param weights: the weights for each point or None for no weights
        :type weights: array-like of size N or None
        :param dofits:
            If True, the data will be fit to the models when the object is
            created, otherwise the models will be passed in as-is (or as
            created).
        :type dofits: bool

        extra keyword arguments get passed in as new traits

        :raises ValueError:
            if data has fewer than 2 columns, or names/models do not match
            the number of data columns.
        """
        super(MultiFitGui, self).__init__(**traits)
        # Snapshot of curveaxes at the last successful update; None until
        # the FitGui panels have been built.
        self._lastcurveaxes = None

        # np.asarray instead of np.array(..., copy=False): the latter raises
        # on NumPy >= 2 whenever a copy is actually needed.
        data = np.asarray(data)
        # Validate before assigning: assigning self.data triggers
        # _data_changed/_curveaxes_update, which fail on < 2 columns.
        if data.shape[0] < 2:
            raise ValueError('Must have at least 2 columns')

        if weights is None:
            self.weights = np.ones(data.shape[1])
        else:
            self.weights = np.array(weights)

        self.data = data

        if isinstance(names, basestring):
            names = names.split(',')
        if names is None:
            if len(data) == 2:
                self.axisnames = {0: 'x', 1: 'y'}
            elif len(data) == 3:
                self.axisnames = {0: 'x', 1: 'y', 2: 'z'}
            else:
                # Fixed: the original iterated the data rows here, not the
                # column indices.
                self.axisnames = dict((i, str(i)) for i in range(len(data)))
        elif len(names) == len(data):
            self.axisnames = dict(enumerate(names))
        else:
            raise ValueError("names don't match data")

        # default to using 0th axis as parametric
        self.curveaxes = [(0, i) for i in range(1, len(data))]
        if models is not None:
            if len(models) != len(data) - 1:
                raise ValueError("models don't match data")
            for i, m in enumerate(models):
                fg = self.fgs[i]
                newtmodel = TraitedModel(m)
                if dofits:
                    fg.tmodel = newtmodel
                    # should happen automatically, but this makes sure
                    fg.fitmodel = True
                else:
                    # Keep the model's own parameters: they are saved before
                    # the assignment and restored after it (the assignment
                    # presumably can alter them).
                    oldpard = newtmodel.model.pardict
                    fg.tmodel = newtmodel
                    fg.tmodel.model.pardict = oldpard

    def _title_axes(self, fg, ax):
        """Label fg's plot axes from axisnames for the axis pair *ax*
        ('' for indices without a name)."""
        fg.plot.x_axis.title = self.axisnames.get(ax[0], '')
        fg.plot.y_axis.title = self.axisnames.get(ax[1], '')

    def _data_changed(self):
        self.curveaxes = [(0, i) for i in range(1, len(self.data))]

    def _axisnames_changed(self):
        for ax, fg in zip(self.curveaxes, self.fgs):
            self._title_axes(fg, ax)
        # Third 3D axis falls back to the second name for 2-column data.
        self.plot3daxes = (self.axisnames[0], self.axisnames[1],
                           self.axisnames[2] if len(self.axisnames) > 2
                           else self.axisnames[1])

    @on_trait_change('curveaxes[]')
    def _curveaxes_update(self, names, old, new):
        """Rebuild the FitGui panels whose axis pair changed.

        A new mapping is rejected (reverted to the previous one) unless the
        pairs together use exactly the axis indices 0..c-1.
        """
        used = set(i for pair in self.curveaxes for i in pair)
        if used != set(range(len(self.data))):
            self.curveaxes = self._lastcurveaxes
            return
        # TODO: check for recursion

        last = self._lastcurveaxes
        if last is None or len(last) != len(self.curveaxes):
            # First build, or the number of pairs changed (e.g. new data
            # with another column count): rebuild every panel. Previously
            # a length change raised IndexError below.
            self.fgs = [FitGui(self.data[t[0]], self.data[t[1]],
                               weights=self.weights)
                        for t in self.curveaxes]
            for ax, fg in zip(self.curveaxes, self.fgs):
                self._title_axes(fg, ax)
        else:
            for i, t in enumerate(self.curveaxes):
                if last[i] != t:
                    self.fgs[i] = fg = FitGui(self.data[t[0]],
                                              self.data[t[1]],
                                              weights=self.weights)
                    self._title_axes(fg, t)

        # Store a copy: keeping a reference to the live TraitList meant
        # in-place item edits also changed the snapshot, so they were never
        # detected as changes.
        self._lastcurveaxes = list(self.curveaxes)

    def _show3d_fired(self):
        self.edit_traits(view='plot3d_view')
        self.doplot3d = True
        self.replot3d = True

    def _plot3daxes_changed(self):
        self.replot3d = True

    @on_trait_change('weights', post_init=True)
    def weightsChanged(self):
        """Push new weights into every FitGui as custom weights."""
        for fg in self.fgs:
            if fg.weighttype != 'custom':
                fg.weighttype = 'custom'
            fg.weights = self.weights

    # A single comma-separated extended name: the decorator's extra
    # positional arguments are not trait names, so the original
    # @on_trait_change('data','fgs','replot3d','weights') only listened to
    # 'data' and the Show/Replot buttons never redrew the scene.
    @on_trait_change('data,fgs,replot3d,weights')
    def _do_3d(self):
        """Redraw the 3D scatter plot and, where derivable, the model curve
        obtained by chaining the fitted models from the x axis."""
        if not self.doplot3d:
            return
        M = self.scene3d.mlab
        try:
            xi = self.invaxisnames[self.plot3daxes[0]]
            yi = self.invaxisnames[self.plot3daxes[1]]
            zi = self.invaxisnames[self.plot3daxes[2]]

            x, y, z = self.data[xi], self.data[yi], self.data[zi]
            w = self.weights

            M.clf()
            if self.scalefactor3d == 0:
                # Default glyph size: bounding-box volume per point, / 5.
                # float() avoids Python 2 integer division for int data.
                sf = x.max() - x.min()
                sf *= y.max() - y.min()
                sf *= z.max() - z.min()
                sf = sf / float(len(x)) / 5
                self.scalefactor3d = sf
            else:
                sf = self.scalefactor3d
            glyph = M.points3d(x, y, z, w, scale_factor=sf)
            glyph.glyph.scale_mode = 0 if self.do3dscale else 1
            M.axes(xlabel=self.plot3daxes[0], ylabel=self.plot3daxes[1],
                   zlabel=self.plot3daxes[2])

            try:
                xs = np.linspace(np.min(x), np.max(x), self.nmodel3d)

                # find sequence of models to go from x to y and z: walk back
                # through curveaxes from each target axis to xi, prepending
                # each panel's model so they can be applied starting at x.
                ymods, zmods = [], []
                for curri, mods in zip((yi, zi), (ymods, zmods)):
                    while curri != xi:
                        for i, (i1, i2) in enumerate(self.curveaxes):
                            if curri == i2:
                                curri = i1
                                mods.insert(0, self.fgs[i].tmodel.model)
                                break
                        else:
                            raise KeyError

                ys = xs
                for m in ymods:
                    ys = m(ys)
                zs = xs
                for m in zmods:
                    zs = m(zs)

                if self.usecolor3d:
                    # 255.0 (not 255) so channels are not truncated to 0/1
                    # by Python 2 integer division.
                    c = (self.color3d[0] / 255.0,
                         self.color3d[1] / 255.0,
                         self.color3d[2] / 255.0)
                    M.plot3d(xs, ys, zs, color=c)
                else:
                    M.plot3d(xs, ys, zs, np.arange(len(xs)))
            except (KeyError, TypeError):
                M.text(0.5, 0.75, 'Underivable relation')
        except KeyError:
            M.clf()
            M.text(0.25, 0.25, 'Data problem')

    @cached_property
    def _get_invaxisnames(self):
        return dict((v, k) for k, v in self.axisnames.items())