class CyclesPlot(HasTraits):
    """ Simple plotting class with some attached controls.

    The upper plot is a phase diagram of two metrics (x = "Metric 2",
    y = "Metric 1").  The lower plot is a timeline of both metrics.
    Selecting a date range on the timeline highlights the matching part of
    the phase diagram and marks the last selected point.
    """

    plot = Instance(GridContainer)

    traits_view = View(Item('plot', editor=ComponentEditor(), show_label=False),
                       width=800, height=600, resizable=True,
                       title="Business Cycles Plot")

    # Private Traits
    _file_path = Str       # CSV file the data was read from
    _dates = Array         # observation dates
    _series1 = Array       # full x series ("Metric 2")
    _series2 = Array       # full y series ("Metric 1")
    _selected_s1 = Array   # x values inside the current timeline selection
    _selected_s2 = Array   # y values inside the current timeline selection

    def __init__(self, csv_path="./biz_cycles2.csv", **traits):
        """Read the data and build the two-plot container.

        Parameters
        ----------
        csv_path : str
            CSV file passed to ``read_time_series_from_csv``.  It must
            provide the "Date", "Metric 1" and "Metric 2" columns.  Defaults
            to the file this one-off plot was originally hard-wired to.
        **traits
            Extra trait values forwarded to ``HasTraits.__init__``.
        """
        super(CyclesPlot, self).__init__(**traits)

        self._file_path = csv_path
        srecs = read_time_series_from_csv(csv_path, date_col=0,
                                          date_format="%Y-%m-%d")

        dt = srecs["Date"]

        # Industrial production compared with trend (plotted on value axis)
        iprod_vs_trend = srecs["Metric 1"]

        # Industrial production change in last 6 Months (plotted on index axis)
        iprod_delta = srecs["Metric 2"]

        self._dates = dt
        self._series1 = self._selected_s1 = iprod_delta
        self._series2 = self._selected_s2 = iprod_vs_trend

        end_x = np.array([self._selected_s1[-1]])
        end_y = np.array([self._selected_s2[-1]])

        plotdata = ArrayPlotData(x=self._series1, y=self._series2,
                                 dates=self._dates,
                                 selected_x=self._selected_s1,
                                 selected_y=self._selected_s2,
                                 endpoint_x=end_x, endpoint_y=end_y)

        cycles = Plot(plotdata, padding=20)

        cycles.plot(("x", "y"), type="line", color=(.2, .4, .5, .4))

        cycles.plot(("selected_x", "selected_y"), type="line",
                    marker="circle", line_width=3, color=(.2, .4, .5, .9))

        cycles.plot(("endpoint_x", "endpoint_y"), type="scatter",
                    marker_size=4, marker="circle", color=(.2, .4, .5, .2),
                    outline_color=(.2, .4, .5, .6))

        cycles.index_range = DataRange1D(low_setting=80., high_setting=120.)
        cycles.value_range = DataRange1D(low_setting=80., high_setting=120.)

        # dig down to use actual Plot object
        cyc_plot = cycles.components[0]

        # Add the labels in the quadrants
        cyc_plot.overlays.append(
            PlotLabel("\nSlowdown" + 40 * " " + "Expansion",
                      component=cyc_plot, font="swiss 24",
                      color=(.2, .4, .5, .6),
                      overlay_position="inside top"))

        cyc_plot.overlays.append(
            PlotLabel("Downturn" + 40 * " " + "Recovery\n ",
                      component=cyc_plot, font="swiss 24",
                      color=(.2, .4, .5, .6),
                      overlay_position="inside bottom"))

        timeline = Plot(plotdata, resizable='h', height=50, padding=20)
        timeline.plot(("dates", "x"), type="line", color=(.2, .4, .5, .8),
                      name='x')
        timeline.plot(("dates", "y"), type="line", color=(.5, .4, .2, .8),
                      name='y')

        # Snap on the tools
        # NOTE(review): these two tools are constructed but never appended to
        # timeline.tools / timeline.overlays, so they are currently inert.
        # Attaching the PanTool as-is (default left-button drag) would
        # conflict with the left-button RangeSelection below.  Confirm intent.
        zoomer = ZoomTool(timeline, drag_button="right", always_on=True,
                          tool_mode="range", axis="index",
                          max_zoom_out_factor=1.1)
        panner = PanTool(timeline, constrain=True, constrain_direction="x")

        # dig down to get Plot component I want
        x_plt = timeline.plots['x'][0]

        range_selection = RangeSelection(x_plt, left_button_selects=True)
        range_selection.on_trait_change(self.update_interval, 'selection')

        x_plt.tools.append(range_selection)
        x_plt.overlays.append(RangeSelectionOverlay(x_plt))

        # Set the plot's bottom axis to use the Scales ticking system
        scale_sys = CalendarScaleSystem(
            fill_ratio=0.4,
            default_numlabels=5,
            default_numticks=10,
        )
        tick_gen = ScalesTickGenerator(scale=scale_sys)

        bottom_axis = ScalesPlotAxis(timeline, orientation="bottom",
                                     tick_generator=tick_gen)

        # Hack to remove default axis - FIXME: how do I *replace* an axis?
        del (timeline.underlays[-2])

        timeline.overlays.append(bottom_axis)

        container = GridContainer(padding=20, fill_padding=True,
                                  bgcolor="lightgray", use_backbuffer=True,
                                  shape=(2, 1), spacing=(30, 30))

        # add a central "x" and "y" axis
        x_line = LineInspector(cyc_plot, is_listener=True, color="gray",
                               width=2)
        y_line = LineInspector(cyc_plot, is_listener=True, color="gray",
                               width=2, axis="value")

        cyc_plot.overlays.append(x_line)
        cyc_plot.overlays.append(y_line)

        cyc_plot.index.metadata["selections"] = 100.0
        cyc_plot.value.metadata["selections"] = 100.0

        container.add(cycles)
        container.add(timeline)

        container.title = "Business Cycles"

        self.plot = container

    def update_interval(self, value):
        """Re-highlight the phase plot for the timeline's range selection.

        This is registered as the RangeSelection ``selection`` listener.  The
        selection is re-read from the timeline plot's index metadata, not
        from ``value``.  With no selection, the full series is highlighted and
        its last point is marked.  When the selection contains no dates, both
        the highlighted segment and the endpoint marker are left empty.
        """
        # Reaching pretty deep here to get selections
        sels = self.plot.plot_components[1].plots['x'][0].index.metadata[
            'selections']

        if sels is not None:
            msk = (self._dates >= sels[0]) & (self._dates <= sels[1])
            self._selected_s1 = self._series1[msk]
            self._selected_s2 = self._series2[msk]

            selected_idx = np.flatnonzero(msk)
            if len(selected_idx):
                # Mark the last point that falls inside the selection.
                last_idx = selected_idx[-1]
                endpoint_x = np.array([self._series1[last_idx]])
                endpoint_y = np.array([self._series2[last_idx]])
            else:
                # Empty selection: do not mark a point that is not part of
                # the (empty) highlighted segment.
                endpoint_x = np.array([])
                endpoint_y = np.array([])
        else:
            self._selected_s1 = self._series1
            self._selected_s2 = self._series2
            endpoint_x = np.array([self._series1[-1]])
            endpoint_y = np.array([self._series2[-1]])

        data = self.plot.plot_components[0].data
        data['selected_x'] = self._selected_s1
        data['selected_y'] = self._selected_s2
        data['endpoint_x'] = endpoint_x
        data['endpoint_y'] = endpoint_y
class InPaintDemo(HasTraits):
    """Interactive OpenCV inpainting demo.

    The user brushes a selection over the image with a ``CirclePainter``.
    Brushed circles are filled into a single-channel mask, and ``cv.inpaint``
    repairs the masked pixels into a preview image.  The preview can be shown
    with the mask highlighted, or committed back into the source image.
    """
    plot = Instance(Plot)
    painter = Instance(CirclePainter)

    # radius handed to cv.inpaint
    r = Range(2.0, 20.0, 10.0)
    # inpaint algorithm; the value is the name of a cv constant
    method = Enum("INPAINT_NS", "INPAINT_TELEA")
    # when True, show the selection mask painted white over the image
    show_mask = Bool(False)
    # button: clear the selection
    clear_mask = Button("清除选区")
    # button: keep the inpainted result
    apply = Button("保存结果")

    # UI labels: brush radius, inpaint radius, inpaint algorithm,
    # show selection; window title: "inpaint Demo control panel".
    view = View(
        VGroup(
            VGroup(
                Item("object.painter.r", label="画笔半径"),
                Item("r", label="inpaint半径"),
                HGroup(
                    Item("method", label="inpaint算法"),
                    Item("show_mask", label="显示选区"),
                    Item("clear_mask", show_label=False),
                    Item("apply", show_label=False),
                ),
            ),
            Item("plot", editor=ComponentEditor(), show_label=False),
        ),
        title="inpaint Demo控制面板",
        width=500,
        height=450,
        resizable=True,
    )

    def __init__(self, *args, **kwargs):
        super(InPaintDemo, self).__init__(*args, **kwargs)
        # Source image, the inpaint preview, and the selection mask.
        self.img = cv.imread("stuff.jpg")
        self.img2 = self.img.clone()
        self.mask = cv.Mat(self.img.size(), cv.CV_8UC1)
        self.mask[:] = 0

        # The image is displayed with its channel order reversed.
        self.data = ArrayPlotData(img=self.img[:, :, ::-1])
        ratio = float(self.img.size().width) / self.img.size().height
        self.plot = Plot(self.data, padding=10, aspect_ratio=ratio)
        for axis in (self.plot.x_axis, self.plot.y_axis):
            axis.visible = False

        image_renderer = self.plot.img_plot("img", origin="top left")[0]
        self.painter = CirclePainter(component=image_renderer)
        image_renderer.overlays.append(self.painter)

    @on_trait_change("r,method")
    def inpaint(self):
        """Rebuild the preview ``img2`` from ``img`` and ``mask``, then redraw."""
        algorithm = getattr(cv, self.method)
        cv.inpaint(self.img, self.mask, self.img2, self.r, algorithm)
        self.draw()

    @on_trait_change("painter:updated")
    def painter_updated(self):
        """Burn the painter's tracked circles into the mask and re-inpaint."""
        for _, _, px, py in self.painter.track:
            centre = cv.Point(int(px), int(py))
            # A negative thickness makes cv.circle draw a filled disc.
            cv.circle(self.mask, centre, int(self.painter.r),
                      cv.Scalar(255, 255, 255, 255), thickness=-1)
        self.inpaint()
        self.painter.track = []
        self.painter.request_redraw()

    def _clear_mask_fired(self):
        """Reset the selection mask and refresh the preview."""
        self.mask[:] = 0
        self.inpaint()

    def _apply_fired(self):
        """Copy the inpainted preview into the source image, then clear the mask."""
        self.img[:] = self.img2[:]
        self._clear_mask_fired()

    @on_trait_change("show_mask")
    def draw(self):
        """Push either the preview or the source-plus-mask image to the plot."""
        if not self.show_mask:
            self.data["img"] = self.img2[:, :, ::-1]
            return
        highlighted = self.img[:, :, ::-1].copy()
        highlighted[self.mask[:] > 0] = 255
        self.data["img"] = highlighted
def call_mlab(self, scene=None, show=True, is_3d=False,
              view=None, roll=None,
              fgcolor=(0.0, 0.0, 0.0), bgcolor=(1.0, 1.0, 1.0),
              layout='rowcol', scalar_mode='iso_surface',
              vector_mode='arrows_norm', rel_scaling=None, clamping=False,
              ranges=None, is_scalar_bar=False, is_wireframe=False,
              opacity=None, subdomains_args=None, rel_text_width=None,
              fig_filename='view.png', resolution=None,
              filter_names=None, only_names=None, group_names=None,
              step=None, time=None, anti_aliasing=None,
              domain_specific=None):
    """
    Set up the scene (offscreen figure or GUI window), the file source and
    the time-step controls, then render the data.

    By default, all data (point, cell, scalars, vectors, tensors) are
    plotted in a grid layout, except data named 'node_groups', 'mat_id'
    which are usually not interesting.

    Most rendering parameters are not used directly here. They are recorded
    in ``self.options`` (via ``get_arguments()``) and consumed by
    ``self.render_scene()``.

    Parameters
    ----------
    scene : scene object, optional
        If given, render into this scene and do not create a GUI. If None,
        reuse ``self.scene`` when it is still running. Otherwise create an
        offscreen figure (when ``self.offscreen``) or a ``ViewerGUI``.
    show : bool
        Only used when a GUI is created for a new scene. If True, show it
        with ``configure_traits()``; otherwise with ``edit_traits()``.
    is_3d : bool
        If True, use scalar cut planes instead of surface for certain
        datasets. Also sets 3D view mode.
    view : tuple
        Azimuth, elevation angles, distance and focal point as in
        `mlab.view()`.
    roll : float
        Roll angle tuple as in mlab.roll().
    fgcolor : tuple of floats (R, G, B)
        The foreground color, that is the color of all text annotation
        labels (axes, orientation axes, scalar bar labels).
    bgcolor : tuple of floats (R, G, B)
        The background color.
    layout : str
        Grid layout for placing the datasets. Possible values are:
        'row', 'col', 'rowcol', 'colrow'.
    scalar_mode : str or sequence of str
        Mode for plotting scalars and tensor magnitudes, one of
        'cut_plane', 'iso_surface', 'both'. A sequence of 'cut_plane' and/or
        'iso_surface' is also accepted.
    vector_mode : str or sequence of str
        Mode for plotting vectors, one of 'arrows', 'norm', 'arrows_norm',
        'warp_norm'. 'cut_plane' is also accepted: it is used as-is when
        `is_3d` is True and replaced by 'arrows' otherwise. A sequence of
        'arrows', 'norm' and/or 'warp' is also accepted.
    rel_scaling : float
        Relative scaling of glyphs for vector datasets.
    clamping : bool
        Clamping for vector datasets.
    ranges : dict
        List of data ranges in the form {name : (min, max), ...}.
    is_scalar_bar : bool
        If True, show a scalar bar for each data.
    is_wireframe : bool
        If True, show a wireframe of the mesh surface for each data.
    opacity : float
        Global surface and wireframe opacity setting in [0.0, 1.0].
    subdomains_args : tuple
        Tuple of (mat_id_name, threshold_limits, single_color), see
        :func:`add_subdomains_surface`, or None.
    rel_text_width : float
        Relative text width. If None, 0.02 is used.
    fig_filename : str
        File name for saving the resulting scene figure.
    resolution : tuple
        Scene and figure resolution. If None, it is set
        automatically according to the layout.
    filter_names : list of strings
        Omit the listed datasets. If None, it is initialized to
        ['node_groups', 'mat_id']. Pass [] if you need no filtering.
    only_names : list of strings
        Draw only the listed datasets. If None, it is initialized all names
        besides those in filter_names.
    group_names : list of tuples
        List of data names in the form [(name1, ..., nameN), (...)]. Plots
        of data named in each group are superimposed. Repetitions of names
        are possible.
    step : int, optional
        If not None, the time step to display. The closest higher step is
        used if the desired one is not available. Has precedence over
        `time`. Negative values count back from the last step.
    time : float, optional
        If not None, the time of the time step to display. The closest
        higher time is used if the desired one is not available.
    anti_aliasing : int
        Value of anti-aliasing.
    domain_specific : dict
        Domain-specific drawing functions and configurations.

    Returns
    -------
    gui : ViewerGUI or None
        The GUI object, or None when rendering offscreen or into a
        caller-supplied `scene`.

    Raises
    ------
    ValueError
        If `scalar_mode` or `vector_mode` has an invalid value.
    """
    self.fgcolor = fgcolor
    self.bgcolor = bgcolor

    if filter_names is None:
        filter_names = ['node_groups', 'mat_id']

    if rel_text_width is None:
        rel_text_width = 0.02

    # Normalize scalar_mode to a tuple of mode names.
    if isinstance(scalar_mode, basestr):
        if scalar_mode == 'both':
            scalar_mode = ('cut_plane', 'iso_surface')
        elif scalar_mode in ('cut_plane', 'iso_surface'):
            scalar_mode = (scalar_mode, )
        else:
            raise ValueError('bad value of scalar_mode parameter! (%s)'
                             % scalar_mode)
    else:
        for sm in scalar_mode:
            if not sm in ('cut_plane', 'iso_surface'):
                raise ValueError(
                    'bad value of scalar_mode parameter! (%s)' % sm)

    # Normalize vector_mode to a tuple of mode names.
    if isinstance(vector_mode, basestr):
        if vector_mode == 'arrows_norm':
            vector_mode = ('arrows', 'norm')
        elif vector_mode == 'warp_norm':
            vector_mode = ('warp', 'norm')
        elif vector_mode in ('arrows', 'norm'):
            vector_mode = (vector_mode, )
        elif vector_mode == 'cut_plane':
            # Cut planes only make sense in 3D; fall back to arrows in 2D.
            if is_3d:
                vector_mode = ('cut_plane', )
            else:
                vector_mode = ('arrows', )
        else:
            raise ValueError('bad value of vector_mode parameter! (%s)'
                             % vector_mode)
    else:
        for vm in vector_mode:
            if not vm in ('arrows', 'norm', 'warp'):
                raise ValueError(
                    'bad value of vector_mode parameter! (%s)' % vm)

    mlab.options.offscreen = self.offscreen

    self.size_hint = self.get_size_hint(layout, resolution=resolution)

    is_new_scene = False

    if scene is not None:
        # Caller-supplied scene: render into it, no GUI.
        if scene is not self.scene:
            is_new_scene = True
            self.scene = scene
        gui = None

    else:
        # Drop a previously used scene that is no longer running.
        if (self.scene is not None) and (not self.scene.running):
            self.scene = None

        if self.scene is None:
            if self.offscreen:
                gui = None
                scene = mlab.figure(fgcolor=fgcolor, bgcolor=bgcolor,
                                    size=self.size_hint)

            else:
                gui = ViewerGUI(viewer=self,
                                fgcolor=fgcolor, bgcolor=bgcolor)
                scene = gui.scene.mayavi_scene

            if scene is not self.scene:
                is_new_scene = True
                self.scene = scene

        else:
            # Reuse the existing, still-running scene and its GUI.
            gui = self.gui
            scene = self.scene

    self.engine = mlab.get_engine()
    self.engine.current_scene = self.scene

    self.gui = gui

    self.file_source = create_file_source(self.filename, watch=self.watch,
                                          offscreen=self.offscreen)
    steps, times = self.file_source.get_ts_info()
    has_several_times = len(times) > 0
    has_several_steps = has_several_times or (len(steps) > 0)
    if gui is not None:
        gui.has_several_steps = has_several_steps

    self.reload_source = reload_source = ReloadSource()
    reload_source._viewer = self
    reload_source._source = self.file_source

    if has_several_steps:
        self.set_step = set_step = SetStep()
        set_step._viewer = self
        set_step._source = self.file_source
        if step is not None:
            # Negative steps are counted back from the last step.
            step = step if step >= 0 else steps[-1] + step + 1
            # NOTE(review): assert_ is used to reject an out-of-range step;
            # presumably it raises on failure -- confirm its semantics.
            assert_(steps[0] <= step <= steps[-1],
                    msg='invalid time step! (%d <= %d <= %d)'
                    % (steps[0], step, steps[-1]))
            set_step.step = step

        elif time is not None:
            assert_(times[0] <= time <= times[-1],
                    msg='invalid time! (%e <= %e <= %e)'
                    % (times[0], time, times[-1]))
            set_step.time = time

        else:
            set_step.step = steps[0]

        if self.watch:
            self.file_source.setup_notification(set_step, 'file_changed')

        if gui is not None:
            gui.set_step = set_step

    else:
        if self.watch:
            self.file_source.setup_notification(reload_source,
                                                'reload_source')

    # NOTE(review): get_arguments() appears to collect this function's
    # arguments (presumably from the current frame), so parameters rebound
    # above (filter_names, rel_text_width, scalar_mode, vector_mode, scene,
    # step) would be stored with their normalized values -- confirm.
    self.options.update(get_arguments(omit=['self', 'file_source']))

    if gui is None:
        self.render_scene(scene, self.options)
        self.reset_view()
        if is_scalar_bar:
            self.show_scalar_bars(self.scalar_bars)

    else:
        traits_view = View(
            Item('scene', editor=SceneEditor(scene_class=MayaviScene),
                 show_label=False,
                 width=self.size_hint[0], height=self.size_hint[1],
                 style='custom',
                 ),
            Group(Item('set_step', defined_when='set_step is not None',
                       show_label=False, style='custom'),
                  ),
            HGroup(spring,
                   Item('button_make_snapshots_steps', show_label=False,
                        enabled_when='has_several_steps == True'),
                   Item('button_make_animation_steps', show_label=False,
                        enabled_when='has_several_steps == True'),
                   spring,
                   Item('button_make_snapshots_times', show_label=False,
                        enabled_when='has_several_steps == True'),
                   Item('button_make_animation_times', show_label=False,
                        enabled_when='has_several_steps == True'),
                   spring,),
            HGroup(spring,
                   Item('button_reload', show_label=False),
                   Item('button_view', show_label=False),
                   Item('button_quit', show_label=False)),
            resizable=True,
            buttons=[],
            handler=ClosingHandler(),
        )

        # Only open a window when the GUI belongs to a freshly created scene.
        if is_new_scene:
            if show:
                gui.configure_traits(view=traits_view)

            else:
                gui.edit_traits(view=traits_view)

    return gui
# -*- coding: utf-8 -*- """ 演示如何自定义控制器类,和各种监听函数 """ from enthought.traits.api import HasTraits, Str, Int from enthought.traits.ui.api import View, Item, Group, Handler from enthought.traits.ui.menu import ModalButtons g1 = [Item('department', label=u"部门"), Item('name', label=u"姓名")] g2 = [Item('salary', label=u"工资"), Item('bonus', label=u"奖金")] class Employee(HasTraits): name = Str department = Str salary = Int bonus = Int def _department_changed(self): print self, "department changed to ", self.department def __str__(self): return "<Employee at 0x%x>" % id(self) view1 = View(Group(*g1, label=u'个人信息', show_border=True), Group(*g2, label=u'收入', show_border=True), title=u"外部视图", kind="modal", buttons=ModalButtons)
from labtools.analysis.plot import Plot
from labtools.utils.data_viewer import StructArrayData
import numpy as np
from labtools.log import create_logger

log = create_logger(__name__)


class DlsError(Exception):
    """Exception type for DLS-analysis errors."""


# Analyzer panel layout: file list, constants editor, then the two
# processing buttons.
dls_analyzer_group = Group(
    Item('filenames', style='custom', show_label=False),
    'constants',
    Item('process_selected_btn', show_label=False),
    Item('process_all_btn', show_label=False),
)


class DlsFitter(DataFitter):
    """In addition to :class:`DataFitter` it defines :meth:`open_dls` to open
    dls data.

    Its default plotter shows g2-1 against lag time (ms) on a logarithmic
    x axis.
    """

    def _plotter_default(self):
        plot_options = dict(
            title='g2 -1',
            xlabel='Lag time [ms]',
            ylabel='g2-1',
            xscale='log',
        )
        return Plot(**plot_options)
label='name', view=no_view), TreeNode(node_for=[ControlPoint], auto_open=False, children='', label='label', view=no_view), ], on_select=on_tree_select) # The main view view = View( Group( Item( name = 'patientlist', id = 'patientlist', editor = tree_editor, resizable = True ), orientation = 'vertical', show_labels = True, show_left = False, ), title = 'Patients', id = \ 'dicomutils.viewer.tree', dock = 'horizontal', drop_class = HasTraits, handler = TreeHandler(), buttons = [ 'Undo', 'OK', 'Cancel' ], resizable = True, width = .3, height = .3 )
class Case(HasTraits):
    '''
    A class representing an avl input file
    '''
    name = Str()
    mach_no = Float()
    symmetry = List(minlen=3, maxlen=3)  # iYsym, iZsym, Zsym
    ref_area = Float()
    ref_chord = Float()
    ref_span = Float()
    # `numpy.float` was merely an alias of the builtin `float` and was removed
    # in NumPy 1.24, so the builtin is used directly (same dtype).
    ref_cg = Array(float, (3, ))
    CD_p = Float  # profile drag coefficient; 0.0 means "not specified"
    geometry = Instance(Geometry)
    cwd = Directory
    case_filename = File('')
    traits_view = View(Item('name'),
                       Item('mach_no'),
                       Item('symmetry'),
                       Item('ref_area'),
                       Item('ref_chord'),
                       Item('ref_span'),
                       Item('ref_cg', editor=ArrayEditor()),
                       Item('CD_p'))

    controls = DelegatesTo('geometry')

    def write_input_file(self, file):
        '''
        Write all the data in the case in the appropriate format as in
        input .avl file for the AVL program.

        :param file: writable text file-like object.
        '''
        file.write(self.name + '\n')
        file.write('#Mach no\n%f\n' % self.mach_no)
        file.write('#iYsym iZsym Zsym\n%s %s %s\n' % tuple(self.symmetry))
        file.write('#Sref Cref Bref\n%f %f %f\n' % (self.ref_area,
                                                    self.ref_chord,
                                                    self.ref_span))
        file.write('#Xref Yref Zref\n%f %f %f\n' % tuple(self.ref_cg))
        # The CD_p line is optional; it is only written when non-zero.
        if self.CD_p != 0.0:
            file.write('#CD_p profile drag coefficient\n%f\n' % self.CD_p)
        file.write('\n')
        file.write('#' * 70)
        file.write('\n')
        self.geometry.write_to_file(file)

    @classmethod
    def case_from_input_file(cls, file, cwd=''):
        '''
        Return an instance of ``cls`` by reading its data from an input file.

        :param file: readable text file-like object (``readlines()`` is used).
        :param cwd: directory passed to the geometry parser and stored on the
            new case.
        '''
        lines = filter_lines(file.readlines())
        name = lines[0]
        mach_no = float(lines[1].split()[0])
        symmetry = lines[2].split()
        symmetry = [int(symmetry[0]), int(symmetry[1]), float(symmetry[2])]
        ref_area, ref_chord, ref_span = [
            float(value) for value in lines[3].split()[:3]
        ]
        ref_cg = [float(value) for value in lines[4].split()[:3]]
        # The CD_p line is optional: if line 5 is missing, empty, or not
        # numeric, the geometry section starts there instead.
        lineno = 5
        try:
            CD_p = float(lines[5].split()[0])
            lineno = 6
        except (ValueError, IndexError):
            CD_p = 0.0
        geometry = Geometry.create_from_lines(lines, lineno, cwd=cwd)
        return cls(name=name,
                   mach_no=mach_no,
                   symmetry=symmetry,
                   ref_area=ref_area,
                   ref_chord=ref_chord,
                   ref_span=ref_span,
                   ref_cg=ref_cg,
                   CD_p=CD_p,
                   geometry=geometry,
                   cwd=cwd)
class GenerateProjectorCalibration(HasTraits):
    """Interactive editor for a virtual display's viewport polygon.

    The user clicks polygon vertices on a test image (red/blue stripes plus
    the filled polygon).  Each change re-renders the image and sends it as a
    PNG through ``blit_compressed_image_proxy``.  Once the polygon has at
    least three vertices, ``update_ROS_params`` is also called.
    """
    # NOTE(review): self.width / self.height are read below but are not
    # declared as traits here; presumably they are supplied as keyword
    # arguments to __init__ -- confirm.
    display_id = traits.String
    plot = Instance(Component)
    linedraw = Instance(LineSegmentTool)
    viewport_id = traits.String('viewport_0')
    display_mode = traits.Trait('white on black', 'black on white')
    client = traits.Any
    blit_compressed_image_proxy = traits.Any
    set_display_server_mode_proxy = traits.Any

    traits_view = View(
        Group(Item('display_mode'),
              Item('viewport_id'),
              Item('plot', editor=ComponentEditor(), show_label=False),
              orientation="vertical"),
        resizable=True,
    )

    def __init__(self, *args, **kwargs):
        """Create the editor.

        Requires the keyword argument ``display_coords_filename``.  The
        previously stored viewport is restored from the ROS parameter server
        when all of its vertices lie inside the image.
        """
        display_coords_filename = kwargs.pop('display_coords_filename')
        super(GenerateProjectorCalibration, self).__init__(*args, **kwargs)

        # NOTE(review): the unpickled contents are not used below; the file
        # is still loaded so a missing/corrupt file fails early.  pickle can
        # execute arbitrary code -- only pass trusted files.
        with open(display_coords_filename, mode='rb') as fd:
            pickle.load(fd)

        self.param_name = 'virtual_display_config_json_string'
        self.fqdn = '/virtual_displays/' + self.display_id + '/' + self.viewport_id
        self.fqpn = self.fqdn + '/' + self.param_name
        self.client = dynamic_reconfigure.client.Client(self.fqdn)

        self._update_image()

        virtual_display_json_str = rospy.get_param(self.fqpn)
        this_virtual_display = json.loads(virtual_display_json_str)

        # error check: only restore the stored polygon if every vertex lies
        # inside the image.
        viewport = this_virtual_display['viewport']
        if not any((x >= self.width) or (y >= self.height)
                   for (x, y) in viewport):
            self.linedraw.points = viewport

        self._update_image()

    def _update_image(self):
        """Re-render the test image, push it to the plot and the display."""
        self._image = np.zeros((self.height, self.width, 3), dtype=np.uint8)
        # draw polygon (mahotas expects (row, col) vertices)
        if len(self.linedraw.points) >= 3:
            pts = [(posint(y, self.height - 1), posint(x, self.width - 1))
                   for (x, y) in self.linedraw.points]
            mahotas.polygon.fill_polygon(pts, self._image[:, :, 0])
            self._image[:, :, 0] *= 255
            self._image[:, :, 1] = self._image[:, :, 0]
            self._image[:, :, 2] = self._image[:, :, 0]

        # draw red horizontal stripes
        for i in range(0, self.height, 100):
            self._image[i:i + 10, :, 0] = 255

        # draw blue vertical stripes
        for i in range(0, self.width, 100):
            self._image[:, i:i + 10, 2] = 255

        # _pd only exists once the plot has been built.
        if hasattr(self, '_pd'):
            self._pd.set_data("imagedata", self._image)
        self.send_array()
        if len(self.linedraw.points) >= 3:
            self.update_ROS_params()

    def _plot_default(self):
        """Build the image plot with pan (right drag) and box-zoom tools."""
        self._pd = ArrayPlotData()
        self._pd.set_data("imagedata", self._image)

        plot = Plot(self._pd, default_origin="top left")
        plot.x_axis.orientation = "top"
        plot.img_plot("imagedata")

        plot.bgcolor = "white"

        # Tweak some of the plot properties
        plot.title = "Click to add points, press Enter to clear selection"
        plot.padding = 50
        plot.line_width = 1

        # Attach some tools to the plot
        pan = PanTool(plot, drag_button="right", constrain_key="shift")
        plot.tools.append(pan)
        zoom = ZoomTool(component=plot, tool_mode="box", always_on=False)
        plot.overlays.append(zoom)

        return plot

    def _linedraw_default(self):
        """Attach the polygon-drawing tool and re-render when points change."""
        linedraw = LineSegmentTool(self.plot, color=(0.5, 0.5, 0.9, 1.0))
        self.plot.overlays.append(linedraw)
        linedraw.on_trait_change(self.points_changed, 'points[]')
        return linedraw

    def points_changed(self):
        self._update_image()

    @traits.on_trait_change('display_mode')
    def send_array(self):
        """Encode the current image as PNG and blit it to the display server.

        NOTE(review): ``display_mode`` triggers a resend but is not applied to
        the image (the original computed fore/background colours and never
        used them) -- confirm intent.
        """
        # mkstemp creates the file atomically (mktemp is race-prone); the
        # descriptor is closed because imsave reopens the file by name.
        fd, fname = tempfile.mkstemp(suffix='.png')
        os.close(fd)
        try:
            scipy.misc.imsave(fname, self._image)
            image = freemoovr.msg.FreemooVRCompressedImage()
            image.format = 'png'
            with open(fname, 'rb') as png_file:
                image.data = png_file.read()
            self.blit_compressed_image_proxy(image)
        finally:
            os.unlink(fname)

    def get_viewport_verts(self):
        """Return the polygon vertices as a list of [x, y] integer lists."""
        # convert to integers
        pts = [(posint(x, self.width - 1), posint(y, self.height - 1))
               for (x, y) in self.linedraw.points]
        # convert to list of lists for maximal json compatibility
        return [list(x) for x in pts]
class FitGui(HasTraits): """ This class represents the fitgui application state. """ plot = Instance(Plot) colorbar = Instance(ColorBar) plotcontainer = Instance(HPlotContainer) tmodel = Instance(TraitedModel,allow_none=False) nomodel = Property newmodel = Button('New Model...') fitmodel = Button('Fit Model') showerror = Button('Fit Error') updatemodelplot = Button('Update Model Plot') autoupdate = Bool(True) data = Array(dtype=float,shape=(2,None)) weights = Array weighttype = Enum(('custom','equal','lin bins','log bins')) weightsvary = Property(Bool) weights0rem = Bool(True) modelselector = NewModelSelector ytype = Enum(('data and model','residuals')) zoomtool = Instance(ZoomTool) pantool = Instance(PanTool) scattertool = Enum(None,'clicktoggle','clicksingle','clickimmediate','lassoadd','lassoremove','lassoinvert') selectedi = Property #indecies of the selected objects weightchangesel = Button('Set Selection To') weightchangeto = Float(1.0) delsel = Button('Delete Selected') unselectonaction = Bool(True) clearsel = Button('Clear Selections') lastselaction = Str('None') datasymb = Button('Data Symbol...') modline = Button('Model Line...') savews = Button('Save Weights') loadws = Button('Load Weights') _savedws = Array plotname = Property updatestats = Event chi2 = Property(Float,depends_on='updatestats') chi2r = Property(Float,depends_on='updatestats') nmod = Int(1024) #modelpanel = View(Label('empty'),kind='subpanel',title='model editor') modelpanel = View panel_view = View(VGroup( Item('plot', editor=ComponentEditor(),show_label=False), HGroup(Item('tmodel.modelname',show_label=False,style='readonly'), Item('nmod',label='Number of model points'), Item('updatemodelplot',show_label=False,enabled_when='not autoupdate'), Item('autoupdate',label='Auto?')) ), title='Model Data Fitter' ) selection_view = View(Group( Item('scattertool',label='Selection Mode', editor=EnumEditor(values={None:'1:No Selection', 'clicktoggle':'3:Toggle Select', 'clicksingle':'2:Single 
Select', 'clickimmediate':'7:Immediate', 'lassoadd':'4:Add with Lasso', 'lassoremove':'5:Remove with Lasso', 'lassoinvert':'6:Invert with Lasso'})), Item('unselectonaction',label='Clear Selection on Action?'), Item('clearsel',show_label=False), Item('weightchangesel',show_label=False), Item('weightchangeto',label='Weight'), Item('delsel',show_label=False) ),title='Selection Options') traits_view = View(VGroup( HGroup(Item('object.plot.index_scale',label='x-scaling', enabled_when='object.plot.index_mapper.range.low>0 or object.plot.index_scale=="log"'), spring, Item('ytype',label='y-data'), Item('object.plot.value_scale',label='y-scaling', enabled_when='object.plot.value_mapper.range.low>0 or object.plot.value_scale=="log"') ), Item('plotcontainer', editor=ComponentEditor(),show_label=False), HGroup(VGroup(HGroup(Item('weighttype',label='Weights:'), Item('savews',show_label=False), Item('loadws',enabled_when='_savedws',show_label=False)), Item('weights0rem',label='Remove 0-weight points for fit?'), HGroup(Item('newmodel',show_label=False), Item('fitmodel',show_label=False), Item('showerror',show_label=False,enabled_when='tmodel.lastfitfailure'), VGroup(Item('chi2',label='Chi2:',style='readonly',format_str='%6.6g',visible_when='tmodel.model is not None'), Item('chi2r',label='reduced:',style='readonly',format_str='%6.6g',visible_when='tmodel.model is not None')) )#Item('selbutton',show_label=False)) ,springy=False),spring, VGroup(HGroup(Item('autoupdate',label='Auto?'), Item('updatemodelplot',show_label=False,enabled_when='not autoupdate')), Item('nmod',label='Nmodel'), HGroup(Item('datasymb',show_label=False),Item('modline',show_label=False)),springy=False),springy=True), '_', HGroup(Item('scattertool',label='Selection Mode', editor=EnumEditor(values={None:'1:No Selection', 'clicktoggle':'3:Toggle Select', 'clicksingle':'2:Single Select', 'clickimmediate':'7:Immediate', 'lassoadd':'4:Add with Lasso', 'lassoremove':'5:Remove with Lasso', 'lassoinvert':'6:Invert with 
Lasso'})), Item('unselectonaction',label='Clear Selection on Action?'), Item('clearsel',show_label=False), Item('weightchangesel',show_label=False), Item('weightchangeto',label='Weight'), Item('delsel',show_label=False), ),#layout='flow'), Item('tmodel',show_label=False,style='custom',editor=InstanceEditor(kind='subpanel')) ), handler=FGHandler(), resizable=True, title='Data Fitting', buttons=['OK','Cancel'], width=700, height=900 ) def __init__(self,xdata=None,ydata=None,weights=None,model=None, include_models=None,exclude_models=None,fittype=None,**traits): """ :param xdata: the first dimension of the data to be fit :type xdata: array-like :param ydata: the second dimension of the data to be fit :type ydata: array-like :param weights: The weights to apply to the data. Statistically interpreted as inverse errors (*not* inverse variance). May be any of the following forms: * None for equal weights * an array of points that must match `ydata` * a 2-sequence of arrays (xierr,yierr) such that xierr matches the `xdata` and yierr matches `ydata` * a function called as f(params) that returns an array of weights that match one of the above two conditions :param model: the initial model to use to fit this data :type model: None, string, or :class:`pymodelfit.core.FunctionModel1D` instance. :param include_models: With `exclude_models`, specifies which models should be available in the "new model" dialog (see `models.list_models` for syntax). :param exclude_models: With `include_models`, specifies which models should be available in the "new model" dialog (see `models.list_models` for syntax). :param fittype: The fitting technique for the initial fit (see :class:`pymodelfit.core.FunctionModel`). :type fittype: string kwargs are passed in as any additional traits to apply to the application. 
""" self.modelpanel = View(Label('empty'),kind='subpanel',title='model editor') self.tmodel = TraitedModel(model) if model is not None and fittype is not None: self.tmodel.model.fittype = fittype if xdata is None or ydata is None: if not hasattr(self.tmodel.model,'data') or self.tmodel.model.data is None: raise ValueError('data not provided and no data in model') if xdata is None: xdata = self.tmodel.model.data[0] if ydata is None: ydata = self.tmodel.model.data[1] if weights is None: weights = self.tmodel.model.data[2] self.on_trait_change(self._paramsChanged,'tmodel.paramchange') self.modelselector = NewModelSelector(include_models,exclude_models) self.data = [xdata,ydata] if weights is None: self.weights = np.ones_like(xdata) self.weighttype = 'equal' else: self.weights = np.array(weights,copy=True) self.savews = True weights1d = self.weights while len(weights1d.shape)>1: weights1d = np.sum(weights1d**2,axis=0) pd = ArrayPlotData(xdata=self.data[0],ydata=self.data[1],weights=weights1d) self.plot = plot = Plot(pd,resizable='hv') self.scatter = plot.plot(('xdata','ydata','weights'),name='data', color_mapper=_cmapblack if self.weights0rem else _cmap, type='cmap_scatter', marker='circle')[0] self.errorplots = None if not isinstance(model,FunctionModel1D): self.fitmodel = True self.updatemodelplot = False #force plot update - generates xmod and ymod plot.plot(('xmod','ymod'),name='model',type='line',line_style='dash',color='black',line_width=2) del plot.x_mapper.range.sources[-1] #remove the line plot from the x_mapper source so only the data is tied to the scaling self.on_trait_change(self._rangeChanged,'plot.index_mapper.range.updated') self.pantool = PanTool(plot,drag_button='left') plot.tools.append(self.pantool) self.zoomtool = ZoomTool(plot) self.zoomtool.prev_state_key = KeySpec('a') self.zoomtool.next_state_key = KeySpec('s') plot.overlays.append(self.zoomtool) self.scattertool = None self.scatter.overlays.append(ScatterInspectorOverlay(self.scatter, 
hover_color = "black", selection_color="black", selection_outline_color="red", selection_line_width=2)) self.colorbar = colorbar = ColorBar(index_mapper=LinearMapper(range=plot.color_mapper.range), color_mapper=plot.color_mapper.range, plot=plot, orientation='v', resizable='v', width = 30, padding = 5) colorbar.padding_top = plot.padding_top colorbar.padding_bottom = plot.padding_bottom colorbar._axis.title = 'Weights' self.plotcontainer = container = HPlotContainer(use_backbuffer=True) container.add(plot) container.add(colorbar) super(FitGui,self).__init__(**traits) self.on_trait_change(self._scale_change,'plot.value_scale,plot.index_scale') if weights is not None and len(weights)==2: self.weightsChanged() #update error bars def _weights0rem_changed(self,old,new): if new: self.plot.color_mapper = _cmapblack(self.plot.color_mapper.range) else: self.plot.color_mapper = _cmap(self.plot.color_mapper.range) self.plot.request_redraw() # if old and self.filloverlay in self.plot.overlays: # self.plot.overlays.remove(self.filloverlay) # if new: # self.plot.overlays.append(self.filloverlay) # self.plot.request_redraw() def _paramsChanged(self): self.updatemodelplot = True def _nmod_changed(self): self.updatemodelplot = True def _rangeChanged(self): self.updatemodelplot = True #@on_trait_change('object.plot.value_scale,object.plot.index_scale',post_init=True) def _scale_change(self): self.plot.request_redraw() def _updatemodelplot_fired(self,new): #If the plot has not been generated yet, just skip the update if self.plot is None: return #if False (e.g. 
button click), update regardless, otherwise check for autoupdate if new and not self.autoupdate: return mod = self.tmodel.model if self.ytype == 'data and model': if mod: #xd = self.data[0] #xmod = np.linspace(np.min(xd),np.max(xd),self.nmod) xl = self.plot.index_range.low xh = self.plot.index_range.high if self.plot.index_scale=="log": xmod = np.logspace(np.log10(xl),np.log10(xh),self.nmod) else: xmod = np.linspace(xl,xh,self.nmod) ymod = self.tmodel.model(xmod) self.plot.data.set_data('xmod',xmod) self.plot.data.set_data('ymod',ymod) else: self.plot.data.set_data('xmod',[]) self.plot.data.set_data('ymod',[]) elif self.ytype == 'residuals': if mod: self.plot.data.set_data('xmod',[]) self.plot.data.set_data('ymod',[]) #residuals set the ydata instead of setting the model res = mod.residuals(*self.data) self.plot.data.set_data('ydata',res) else: self.ytype = 'data and model' else: assert True,'invalid Enum' def _fitmodel_fired(self): from warnings import warn preaup = self.autoupdate try: self.autoupdate = False xd,yd = self.data kwd = {'x':xd,'y':yd} if self.weights is not None: w = self.weights if self.weights0rem: if xd.shape == w.shape: m = w!=0 w = w[m] kwd['x'] = kwd['x'][m] kwd['y'] = kwd['y'][m] elif np.any(w==0): warn("can't remove 0-weighted points if weights don't match data") kwd['weights'] = w self.tmodel.fitdata = kwd finally: self.autoupdate = preaup self.updatemodelplot = True self.updatestats = True # def _tmodel_changed(self,old,new): # #old is only None before it is initialized # if new is not None and old is not None and new.model is not None: # self.fitmodel = True def _newmodel_fired(self,newval): from inspect import isclass if isinstance(newval,basestring) or isinstance(newval,FunctionModel1D) \ or (isclass(newval) and issubclass(newval,FunctionModel1D)): self.tmodel = TraitedModel(newval) else: if self.modelselector.edit_traits(kind='modal').result: cls = self.modelselector.selectedmodelclass if cls is None: self.tmodel = TraitedModel(None) 
elif self.modelselector.isvarargmodel: self.tmodel = TraitedModel(cls(self.modelselector.modelargnum)) self.fitmodel = True else: self.tmodel = TraitedModel(cls()) self.fitmodel = True else: #cancelled return def _showerror_fired(self,evt): if self.tmodel.lastfitfailure: ex = self.tmodel.lastfitfailure dialog = HasTraits(s=ex.__class__.__name__+': '+str(ex)) view = View(Item('s',style='custom',show_label=False), resizable=True,buttons=['OK'],title='Fitting error message') dialog.edit_traits(view=view) @cached_property def _get_chi2(self): try: return self.tmodel.model.chi2Data()[0] except: return 0 @cached_property def _get_chi2r(self): try: return self.tmodel.model.chi2Data()[1] except: return 0 def _get_nomodel(self): return self.tmodel.model is None def _get_weightsvary(self): w = self.weights return np.any(w!=w[0])if len(w)>0 else False def _get_plotname(self): xlabel = self.plot.x_axis.title ylabel = self.plot.y_axis.title if xlabel == '' and ylabel == '': return '' else: return xlabel+' vs '+ylabel def _set_plotname(self,val): if isinstance(val,basestring): val = val.split('vs') if len(val) ==1: val = val.split('-') val = [v.strip() for v in val] self.x_axis.title = val[0] self.y_axis.title = val[1] #selection-related def _scattertool_changed(self,old,new): if new == 'No Selection': self.plot.tools[0].drag_button='left' else: self.plot.tools[0].drag_button='right' if old is not None and 'lasso' in old: if new is not None and 'lasso' in new: #connect correct callbacks self.lassomode = new.replace('lasso','') return else: #TODO:test self.scatter.tools[-1].on_trait_change(self._lasso_handler, 'selection_changed',remove=True) del self.scatter.overlays[-1] del self.lassomode elif old == 'clickimmediate': self.scatter.index.on_trait_change(self._immediate_handler, 'metadata_changed',remove=True) self.scatter.tools = [] if new is None: pass elif 'click' in new: smodemap = {'clickimmediate':'single','clicksingle':'single', 'clicktoggle':'toggle'} 
self.scatter.tools.append(ScatterInspector(self.scatter, selection_mode=smodemap[new])) if new == 'clickimmediate': self.clearsel = True self.scatter.index.on_trait_change(self._immediate_handler, 'metadata_changed') elif 'lasso' in new: lasso_selection = LassoSelection(component=self.scatter, selection_datasource=self.scatter.index) self.scatter.tools.append(lasso_selection) lasso_overlay = LassoOverlay(lasso_selection=lasso_selection, component=self.scatter) self.scatter.overlays.append(lasso_overlay) self.lassomode = new.replace('lasso','') lasso_selection.on_trait_change(self._lasso_handler, 'selection_changed') lasso_selection.on_trait_change(self._lasso_handler, 'selection_completed') lasso_selection.on_trait_change(self._lasso_handler, 'updated') else: raise TraitsError('invalid scattertool value') def _weightchangesel_fired(self): self.weights[self.selectedi] = self.weightchangeto if self.unselectonaction: self.clearsel = True self._sel_alter_weights() self.lastselaction = 'weightchangesel' def _delsel_fired(self): self.weights[self.selectedi] = 0 if self.unselectonaction: self.clearsel = True self._sel_alter_weights() self.lastselaction = 'delsel' def _sel_alter_weights(self): if self.weighttype != 'custom': self._customweights = self.weights self.weighttype = 'custom' self.weightsChanged() def _clearsel_fired(self,event): if isinstance(event,list): self.scatter.index.metadata['selections'] = event else: self.scatter.index.metadata['selections'] = list() def _lasso_handler(self,name,new): if name == 'selection_changed': lassomask = self.scatter.index.metadata['selection'].astype(int) clickmask = np.zeros_like(lassomask) clickmask[self.scatter.index.metadata['selections']] = 1 if self.lassomode == 'add': mask = clickmask | lassomask elif self.lassomode == 'remove': mask = clickmask & ~lassomask elif self.lassomode == 'invert': mask = np.logical_xor(clickmask,lassomask) else: raise TraitsError('lassomode is in invalid state') 
self.scatter.index.metadata['selections'] = list(np.where(mask)[0]) elif name == 'selection_completed': self.scatter.overlays[-1].visible = False elif name == 'updated': self.scatter.overlays[-1].visible = True else: raise ValueError('traits event name %s invalid'%name) def _immediate_handler(self): sel = self.selectedi if len(sel) > 1: self.clearsel = True raise TraitsError('selection error in immediate mode - more than 1 selection') elif len(sel)==1: if self.lastselaction != 'None': setattr(self,self.lastselaction,True) del sel[0] def _savews_fired(self): self._savedws = self.weights.copy() def _loadws_fired(self): self.weights = self._savedws self._savews_fired() def _get_selectedi(self): return self.scatter.index.metadata['selections'] @on_trait_change('data,ytype',post_init=True) def dataChanged(self): """ Updates the application state if the fit data are altered - the GUI will know if you give it a new data array, but not if the data is changed in-place. """ pd = self.plot.data #TODO:make set_data apply to both simultaneously? pd.set_data('xdata',self.data[0]) pd.set_data('ydata',self.data[1]) self.updatemodelplot = False @on_trait_change('weights',post_init=True) def weightsChanged(self): """ Updates the application state if the weights/error bars for this model are changed - the GUI will automatically do this if you give it a new set of weights array, but not if they are changed in-place. 
""" weights = self.weights if 'errorplots' in self.trait_names(): #TODO:switch this to updating error bar data/visibility changing if self.errorplots is not None: self.plot.remove(self.errorplots[0]) self.plot.remove(self.errorplots[1]) self.errorbarplots = None if len(weights.shape)==2 and weights.shape[0]==2: xerr,yerr = 1/weights high = ArrayDataSource(self.scatter.index.get_data()+xerr) low = ArrayDataSource(self.scatter.index.get_data()-xerr) ebpx = ErrorBarPlot(orientation='v', value_high = high, value_low = low, index = self.scatter.value, value = self.scatter.index, index_mapper = self.scatter.value_mapper, value_mapper = self.scatter.index_mapper ) self.plot.add(ebpx) high = ArrayDataSource(self.scatter.value.get_data()+yerr) low = ArrayDataSource(self.scatter.value.get_data()-yerr) ebpy = ErrorBarPlot(value_high = high, value_low = low, index = self.scatter.index, value = self.scatter.value, index_mapper = self.scatter.index_mapper, value_mapper = self.scatter.value_mapper ) self.plot.add(ebpy) self.errorplots = (ebpx,ebpy) while len(weights.shape)>1: weights = np.sum(weights**2,axis=0) self.plot.data.set_data('weights',weights) self.plot.plots['data'][0].color_mapper.range.refresh() if self.weightsvary: if self.colorbar not in self.plotcontainer.components: self.plotcontainer.add(self.colorbar) self.plotcontainer.request_redraw() elif self.colorbar in self.plotcontainer.components: self.plotcontainer.remove(self.colorbar) self.plotcontainer.request_redraw() def _weighttype_changed(self, name, old, new): if old == 'custom': self._customweights = self.weights if new == 'custom': self.weights = self._customweights #if hasattr(self,'_customweights') else np.ones_like(self.data[0]) elif new == 'equal': self.weights = np.ones_like(self.data[0]) elif new == 'lin bins': self.weights = binned_weights(self.data[0],10,False) elif new == 'log bins': self.weights = binned_weights(self.data[0],10,True) else: raise TraitError('Invalid Enum value on weighttype') def 
getModelInitStr(self): """ Generates a python code string that can be used to generate a model with parameters matching the model in this :class:`FitGui`. :returns: initializer string """ mod = self.tmodel.model if mod is None: return 'None' else: parstrs = [] for p,v in mod.pardict.iteritems(): parstrs.append(p+'='+str(v)) if mod.__class__._pars is None: #varargs need to have the first argument give the right number varcount = len(mod.params)-len(mod.__class__._statargs) parstrs.insert(0,str(varcount)) return '%s(%s)'%(mod.__class__.__name__,','.join(parstrs)) def getModelObject(self): """ Gets the underlying object representing the model for this fit. :returns: The :class:`pymodelfit.core.FunctionModel1D` object. """ return self.tmodel.model
class ContourGridPlane(Module):
    """Module showing a grid plane, optionally passed through a contour.

    The components are wired as ``source -> grid_plane -> contour`` and the
    actor renders either the contour output (``enable_contours`` True) or
    the grid plane directly (``enable_contours`` False).
    """

    # The version of this class. Used for persistence.
    __version__ = 0

    # The grid plane component.
    grid_plane = Instance(GridPlane, allow_none=False, record=True)

    # Specifies if contouring is to be done or not.
    enable_contours = Bool(True, desc='if contours are generated')

    # The contour component that contours the data.
    contour = Instance(Contour, allow_none=False, record=True)

    # The actor component that represents the visualization.
    actor = Instance(Actor, allow_none=False, record=True)

    # Dataset and attribute kinds this module accepts as input.
    input_info = PipelineInfo(datasets=['image_data',
                                        'structured_grid',
                                        'rectilinear_grid'],
                              attribute_types=['any'],
                              attributes=['any'])

    view = View([Group(Item(name='grid_plane', style='custom'),
                       show_labels=False),
                 Group(Item(name='enable_contours')),
                 Group(Item(name='contour', style='custom',
                            enabled_when='object.enable_contours'),
                       Item(name='actor', style='custom'),
                       show_labels=False)
                 ])

    ######################################################################
    # `Module` interface
    ######################################################################
    def setup_pipeline(self):
        """Override this method so that it *creates* the tvtk pipeline.

        This method is invoked when the object is initialized via
        `__init__`.  Note that at the time this method is called, the tvtk
        data pipeline will *not* yet be setup.  So upstream data will not
        be available.  The idea is that you simply create the basic objects
        and setup those parts of the pipeline not dependent on upstream
        sources and filters.  You should also set the `actors` attribute up
        at this point.
        """
        # Create the components.  Assigning them fires the
        # `_<name>_changed` handlers below, which wire them together.
        self.grid_plane = GridPlane()
        self.contour = Contour(auto_contours=True, number_of_contours=10)
        self.actor = Actor()

    def update_pipeline(self):
        """Override this method so that it *updates* the tvtk pipeline
        when data upstream is known to have changed.

        This method is invoked (automatically) when any of the inputs
        sends a `pipeline_changed` event.
        """
        mm = self.module_manager
        if mm is None:
            return

        # Data is available, so set the input for the grid plane.
        self.grid_plane.inputs = [mm.source]

        # This makes sure that any changes made to enable_contours
        # when the module is not running are updated when it is
        # started.
        self._enable_contours_changed(self.enable_contours)

        # Set the LUT for the mapper.
        self.actor.set_lut(mm.scalar_lut_manager.lut)

        self.pipeline_changed = True

    def update_data(self):
        """Override this method so that it flushes the vtk pipeline if
        that is necessary.

        This method is invoked (automatically) when any of the inputs
        sends a `data_changed` event.
        """
        # Just set data_changed, the components should do the rest if
        # they are connected.
        self.data_changed = True

    ######################################################################
    # Non-public methods.
    ######################################################################
    def _filled_contours_changed(self, value):
        """When filled contours are enabled, the mapper should use the
        cell data, otherwise it should use the default scalar mode.

        Cell-data mode is only applied while the actor actually renders
        the contour.  With contouring disabled the actor shows the grid
        plane, which keeps the default scalar mode (the same rule used by
        `_enable_contours_changed`).
        """
        if value and self.enable_contours:
            self.actor.mapper.scalar_mode = 'use_cell_data'
        else:
            self.actor.mapper.scalar_mode = 'default'
        self.render()

    def _enable_contours_changed(self, value):
        """Turns on and off the contours by re-routing the actor input."""
        # Nothing is connected yet; `update_pipeline` re-applies this.
        if self.module_manager is None:
            return
        if value:
            self.actor.inputs = [self.contour]
            if self.contour.filled_contours:
                self.actor.mapper.scalar_mode = 'use_cell_data'
        else:
            self.actor.inputs = [self.grid_plane]
            self.actor.mapper.scalar_mode = 'default'
        self.render()

    def _grid_plane_changed(self, old, new):
        """Connect a replacement grid plane to its downstream consumers."""
        cont = self.contour
        if cont is not None:
            cont.inputs = [new]
        # With contouring off the actor renders the grid plane directly,
        # so it must follow the replacement instead of keeping the old one.
        actor = self.actor
        if actor is not None and not self.enable_contours:
            actor.inputs = [new]
        self._change_components(old, new)

    def _contour_changed(self, old, new):
        """Re-hook listeners and pipeline inputs for a new contour."""
        if old is not None:
            old.on_trait_change(self._filled_contours_changed,
                                'filled_contours',
                                remove=True)
        new.on_trait_change(self._filled_contours_changed,
                            'filled_contours')
        # Setup the contours input.
        gp = self.grid_plane
        if gp is not None:
            new.inputs = [gp]

        # Setup the actor, but only route it through the contour when
        # contouring is enabled; otherwise it keeps showing the grid plane.
        actor = self.actor
        if actor is not None and self.enable_contours:
            actor.inputs = [new]
        self._change_components(old, new)

    def _actor_changed(self, old, new):
        """Configure a new actor and connect it to the active source."""
        if old is None:
            # First time this is set.
            new.property.set(line_width=2.0)

        # Set the actors scene and input: the contour when contouring is
        # enabled, the bare grid plane otherwise.
        new.scene = self.scene
        source = self.contour if self.enable_contours else self.grid_plane
        if source is not None:
            new.inputs = [source]
        self._change_components(old, new)
class SceneModel(TVTKScene):
    """A TVTKScene stand-in whose rendering is done by a scene editor.

    Actors and widgets are recorded in `actor_list` / `enabled_info`, and
    rendering, saving and sizing requests are forwarded to `scene_editor`.
    Methods that need the editor raise `SceneModelError` until one is
    attached (typically by calling ``edit_traits()`` on the owning object).
    """

    ########################################
    # TVTKScene traits.

    light_manager = Property

    picker = Property

    ########################################
    # SceneModel traits.

    # A convenient dictionary based interface to add/remove actors and widgets.
    # This is similar to the interface provided for the ActorEditor.
    actor_map = Dict()

    # This is used primarily to implement the add_actor/remove_actor methods.
    actor_list = List()

    # The actual scene being edited.
    scene_editor = Instance(TVTKScene)

    # Fired by `render()` to request a render.
    do_render = Event()

    # Fired when this is activated.
    activated = Event()

    # Fired when this widget is closed.
    closing = Event()

    # This exists just to mirror the TVTKWindow api.
    scene = Property

    ###################################
    # View related traits.

    # Render_window's view.
    _stereo_view = Group(Item(name='stereo_render'),
                         Item(name='stereo_type'),
                         show_border=True,
                         label='Stereo rendering',
                         )

    # The default view of this object.
    default_view = View(Group(
                            Group(Item(name='background'),
                                  Item(name='foreground'),
                                  Item(name='parallel_projection'),
                                  Item(name='disable_render'),
                                  Item(name='off_screen_rendering'),
                                  Item(name='jpeg_quality'),
                                  Item(name='jpeg_progressive'),
                                  Item(name='magnification'),
                                  Item(name='anti_aliasing_frames'),
                                  ),
                            Group(Item(name='render_window',
                                       style='custom',
                                       visible_when='object.stereo',
                                       editor=InstanceEditor(view=View(_stereo_view)),
                                       show_label=False),
                                  ),
                            label='Scene'),
                        Group(Item(name='light_manager',
                                   style='custom',
                                   editor=InstanceEditor(),
                                   show_label=False),
                              label='Lights'))

    ###################################
    # Private traits.

    # Used by the editor to determine if the widget was enabled or not.
    enabled_info = Dict()

    def __init__(self, parent=None, **traits):
        """ Initializes the object. """
        # Base class constructor.  We call TVTKScene's super here on purpose.
        # Calling TVTKScene's init will create a new window which we do not
        # want.
        super(TVTKScene, self).__init__(**traits)
        self.control = None

    ######################################################################
    # TVTKScene API.
    ######################################################################
    def render(self):
        """ Request a render of the scene.

        This only fires the `do_render` event; the actual rendering (and
        presumably honouring `disable_render`) is left to whatever listens
        for that event, e.g. the scene editor.
        """
        self.do_render = True

    def add_actors(self, actors):
        """ Adds a single actor or a tuple or list of actors to the
        renderer."""
        if hasattr(actors, '__iter__'):
            self.actor_list.extend(actors)
        else:
            self.actor_list.append(actors)

    def remove_actors(self, actors):
        """ Removes a single actor or a tuple or list of actors from the
        renderer."""
        my_actors = self.actor_list
        if hasattr(actors, '__iter__'):
            for actor in actors:
                my_actors.remove(actor)
        else:
            my_actors.remove(actors)

    # Convenience methods.
    add_actor = add_actors
    remove_actor = remove_actors

    def add_widgets(self, widgets, enabled=True):
        """Adds a single widget or an iterable of widgets to the renderer.

        Each widget's `enabled` flag is recorded in `enabled_info`.
        """
        if hasattr(widgets, '__iter__'):
            # Materialize once: the widgets are walked twice (below and in
            # `add_actors`), so a one-shot iterator such as a generator
            # would otherwise be exhausted before any actor was added.
            widgets = list(widgets)
        else:
            widgets = [widgets]
        for widget in widgets:
            self.enabled_info[widget] = enabled
        self.add_actors(widgets)

    def remove_widgets(self, widgets):
        """Removes a single widget or an iterable of widgets from the
        renderer and forgets their `enabled_info` entries."""
        if hasattr(widgets, '__iter__'):
            # Materialize once: `remove_actors` consumes the iterable, and
            # the loop below needs to walk it again to clean enabled_info.
            widgets = list(widgets)
        else:
            widgets = [widgets]
        self.remove_actors(widgets)
        for widget in widgets:
            del self.enabled_info[widget]

    def reset_zoom(self):
        """Reset the camera so everything in the scene fits.

        Does nothing when no scene editor is attached."""
        if self.scene_editor is not None:
            self.scene_editor.reset_zoom()

    def save(self, file_name, size=None, **kw_args):
        """Saves rendered scene to one of several image formats
        depending on the specified extension of the filename.

        If an additional size (2-tuple) argument is passed the window
        is resized to the specified size in order to produce a
        suitably sized output image.  Please note that when the window
        is resized, the window may be obscured by other widgets and
        the camera zoom is not reset which is likely to produce an
        image that does not reflect what is seen on screen.

        Any extra keyword arguments are passed along to the respective
        image format's save method.
        """
        self._check_scene_editor()
        self.scene_editor.save(file_name, size, **kw_args)

    def save_ps(self, file_name):
        """Saves the rendered scene to a rasterized PostScript image.
        For vector graphics use the save_gl2ps method."""
        self._check_scene_editor()
        self.scene_editor.save_ps(file_name)

    def save_bmp(self, file_name):
        """Save to a BMP image file."""
        self._check_scene_editor()
        self.scene_editor.save_bmp(file_name)

    def save_tiff(self, file_name):
        """Save to a TIFF image file."""
        self._check_scene_editor()
        self.scene_editor.save_tiff(file_name)

    def save_png(self, file_name):
        """Save to a PNG image file."""
        self._check_scene_editor()
        self.scene_editor.save_png(file_name)

    def save_jpg(self, file_name, quality=None, progressive=None):
        """Save to a JPEG image file.

        `quality` is the JPEG quality (values 10-100 are valid) and
        `progressive` toggles progressive JPEG output.  Both are forwarded
        unchanged to the scene editor."""
        self._check_scene_editor()
        self.scene_editor.save_jpg(file_name, quality, progressive)

    def save_iv(self, file_name):
        """Save to an OpenInventor file."""
        self._check_scene_editor()
        self.scene_editor.save_iv(file_name)

    def save_vrml(self, file_name):
        """Save to a VRML file."""
        self._check_scene_editor()
        self.scene_editor.save_vrml(file_name)

    def save_oogl(self, file_name):
        """Saves the scene to a Geomview OOGL file. Requires VTK 4 to
        work."""
        self._check_scene_editor()
        self.scene_editor.save_oogl(file_name)

    def save_rib(self, file_name, bg=0, resolution=None, resfactor=1.0):
        """Save scene to a RenderMan RIB file.

        Keyword Arguments:

        file_name -- File name to save to.

        bg -- Optional background option.  If 0 (the default here) then
        no background is saved; any other non-None value saves a
        background.  An explicit None is forwarded to the scene editor,
        which may ask with a yes/no pop-up (editor-dependent).

        resolution -- Specify the resolution of the generated image in
        the form of a tuple (nx, ny).

        resfactor -- The resolution factor which scales the resolution.
        """
        self._check_scene_editor()
        self.scene_editor.save_rib(file_name, bg, resolution, resfactor)

    def save_wavefront(self, file_name):
        """Save scene to a Wavefront OBJ file.  Two files are
        generated.  One with a .obj extension and another with a .mtl
        extension which contains the material properties.

        Keyword Arguments:

        file_name -- File name to save to
        """
        self._check_scene_editor()
        self.scene_editor.save_wavefront(file_name)

    def save_gl2ps(self, file_name, exp=None):
        """Save scene to a vector PostScript/EPS/PDF/TeX file using
        GL2PS.  If you choose to use a TeX file then note that only
        the text output is saved to the file.  You will need to save
        the graphics separately.

        Keyword Arguments:

        file_name -- File name to save to.

        exp -- Optionally configured vtkGL2PSExporter object.
        Defaults to None and this will use the default settings with
        the output file type chosen based on the extension of the file
        name.
        """
        self._check_scene_editor()
        self.scene_editor.save_gl2ps(file_name, exp)

    def get_size(self):
        """Return size of the render window."""
        self._check_scene_editor()
        return self.scene_editor.get_size()

    def set_size(self, size):
        """Set the size of the window."""
        self._check_scene_editor()
        self.scene_editor.set_size(size)

    def _update_view(self, x, y, z, vx, vy, vz):
        """Used internally to set the view; ignored without an editor."""
        if self.scene_editor is not None:
            self.scene_editor._update_view(x, y, z, vx, vy, vz)

    def _check_scene_editor(self):
        """Raise `SceneModelError` if no scene editor is attached."""
        if self.scene_editor is None:
            msg = """
            This method requires that there be an active scene editor.
            To do this, you will typically need to invoke::
              object.edit_traits()
            where object is the object that contains the SceneModel.
            """
            raise SceneModelError(msg)

    def _scene_editor_changed(self, old, new):
        """Mirror the editor's renderer, render window and interactor, or
        clear them when the editor is detached."""
        if new is None:
            self._renderer = None
            self._renwin = None
            self._interactor = None
        else:
            self._renderer = new._renderer
            self._renwin = new._renwin
            self._interactor = new._interactor

    def _get_picker(self):
        """Getter for the picker."""
        se = self.scene_editor
        if se is not None and hasattr(se, 'picker'):
            return se.picker
        return None

    def _get_light_manager(self):
        """Getter for the light manager."""
        se = self.scene_editor
        if se is not None:
            return se.light_manager
        return None

    ######################################################################
    # SceneModel API.
    ######################################################################
    def _get_scene(self):
        """Getter for the scene property."""
        return self
""" Traits View definition file. The view trait of the parent class is extracted from the model definition file. This file can either be exec()ed or imported. See core/base.py:Base.trait_view() for what is currently used. Using exec() allows view changes without needing to restart Mayavi, but is slower than importing. """ # Authors: Prabhu Ramachandran <*****@*****.**> # Vibha Srinivasan <*****@*****.**> # Judah De Paula <*****@*****.**> # Copyright (c) 2005-2008, Enthought, Inc. # License: BSD Style. from enthought.traits.ui.api import Item, Group, View view = View(Group(Item(name='function'), Item(name='parametric_function', style='custom', resizable=True), label='Function', show_labels=False), Group(Item(name='source', style='custom', resizable=True), label='Source', show_labels=False), resizable=True)
class CMPGUI(PipelineConfiguration): """ The Graphical User Interface for the CMP """ def __init__(self, **kwargs): # NOTE: In python 2.6, object.__init__ no longer accepts input # arguments. HasTraits does not define an __init__ and # therefore these args were being ignored. super(CMPGUI, self).__init__(**kwargs) about = Button run = Button save = Button load = Button help = Button inspect_registration = Button inspect_segmentation = Button inspect_whitemattermask = Button inspect_parcellation = Button inspect_reconstruction = Button inspect_tractography = Button inspect_tractography_filtered = Button inspect_fiberfilter = Button inspect_connectionmatrix = Button main_group = Group( VGroup(Item('project_dir', label='Project Directory:', tooltip='Please select the root folder of your project'), Item( 'generator', label='Generator', ), Item('diffusion_imaging_model', label='Imaging Modality'), label="Project Settings"), HGroup( VGroup( Item('active_dicomconverter', label='DICOM Converter', tooltip="converts DICOM to the Nifti format"), Item('active_registration', label='Registration'), Item('active_segmentation', label='Segmentation'), Item('active_parcellation', label='Parcellation'), Item('active_applyregistration', label='Apply registration'), Item('active_reconstruction', label='Reconstruction'), Item('active_tractography', label='Tractography', tooltip='performs tractography'), Item('active_fiberfilter', label='Fiber Filtering', tooltip='applies filtering operation to the fibers'), Item('active_connectome', label='Connectome Creation', tooltip='creates the connectivity matrices'), # Item('active_statistics', label = 'Statistics'), Item('active_cffconverter', label='CFF Converter', tooltip='converts processed files to a connectome file'), Item('skip_completed_stages', label='Skip Previously Completed Stages:'), label="Stages"), VGroup( #Item('inspect_rawT1', label = 'Inspect Raw T1', show_label = False), #Item('inspect_rawdiff', label = 'Inspect Raw Diffusion', 
show_label = False), Item('inspect_registration', label='Registration', show_label=False), Item('inspect_segmentation', label='Segmentation', show_label=False), #Item('inspect_whitemattermask', label = 'White Matter Mask', show_label = False), Item('inspect_parcellation', label='Parcellation', show_label=False), #Item('inspect_reconstruction', label = 'Reconstruction', show_label = False), # DTB_viewer Item('inspect_tractography', label='Tractography Original', show_label=False), Item('inspect_tractography_filtered', label='Tractography Filtered', show_label=False), Item('inspect_connectionmatrix', label='Connection Matrix', show_label=False), label="Inspector") #VGroup( #label="Status", #) ), label="Main", show_border=False) metadata_group = Group(VGroup( Item('creator', label="Creator"), Item('email', label="E-Mail"), Item('publisher', label="Publisher"), Item('created', label="Creation Date"), Item('modified', label="Modification Date"), Item('license', label="License"), Item('rights', label="Rights"), Item('reference', label="References"), Item('relation', label="Relations"), Item('species', label="Species"), Item('description', label="Project Description"), ), label="Metadata", show_border=False) subject_group = Group(VGroup(Item('subject_name', label="Name"), Item('subject_timepoint', label="Timepoint"), Item('subject_workingdir', label="Working Directory"), Item('subject_metadata', label='Metadata', editor=table_editor), show_border=True), label="Subject") dicomconverter_group = Group(VGroup( Item('do_convert_diffusion', label="Convert Diffusion data?"), Item('subject_raw_glob_diffusion', label="Diffusion File Pattern", enabled_when='do_convert_diffusion'), Item('do_convert_T1', label="Convert T1 data?"), Item('subject_raw_glob_T1', label="T1 File Pattern", enabled_when='do_convert_T1'), Item('do_convert_T2', label="Convert T2 data?"), Item('subject_raw_glob_T2', label="T2 File Pattern", enabled_when='do_convert_T2'), Item('extract_diffusion_metadata', 
label="Try extracting Diffusion metadata"), show_border=True), visible_when="active_dicomconverter", label="DICOM Converter") registration_group = Group( VGroup(Item('registration_mode', label="Registration"), VGroup(Item('lin_reg_param', label='FLIRT Parameters'), enabled_when='registration_mode == "Linear"', label="Linear Registration"), VGroup(Item('nlin_reg_bet_T2_param', label="BET T2 Parameters"), Item('nlin_reg_bet_b0_param', label="BET b0 Parameters"), Item('nlin_reg_fnirt_param', label="FNIRT Parameters"), enabled_when='registration_mode == "Nonlinear"', label="Nonlinear Registration"), show_border=True, enabled_when="active_registration"), visible_when="active_registration", label="Registration", ) parcellation_group = Group( VGroup( Item('parcellation_scheme', label="Parcellation Scheme"), # VGroup( # Item('custompar_nrroi', label="Number of ROI"), # Item('custompar_nodeinfo', label="Node Information (GraphML)"), # Item('custompar_volumeparcell', label="Volumetric parcellation"), # enabled_when = 'parcellation_scheme == "custom"', # label = "Custom Parcellation" # ), # show_border = True, # enabled_when = "active_registration" ), visible_when="active_parcellation", label="Parcellation", ) reconstruction_group = Group( VGroup(Item('nr_of_gradient_directions', label="Number of Gradient Directions"), Item('nr_of_sampling_directions', label="Number of Sampling Directions"), Item('nr_of_b0', label="Number of b0 volumes"), Item('odf_recon_param', label="odf_recon Parameters"), Item('dtb_dtk2dir_param', label="DTB_dtk2dir Parameters"), show_border=True, visible_when="diffusion_imaging_model == 'DSI'"), VGroup(Item('gradient_table', label="Gradient Table"), Item('gradient_table_file', label="Gradient Table File"), Item('nr_of_b0', label="Number of b0 volumes"), Item('max_b0_val', label="Maximum b value"), Item('dti_recon_param', label="dti_recon Parameters"), Item('dtb_dtk2dir_param', label="DTB_dtk2dir Parameters"), show_border=True, 
visible_when="diffusion_imaging_model == 'DTI'"), VGroup( Item('gradient_table', label="Gradient Table"), Item('gradient_table_file', label="Gradient Table File"), Item('nr_of_gradient_directions', label="Number of Gradient Directions"), Item('nr_of_sampling_directions', label="Number of Sampling Directions"), Item('nr_of_b0', label="Number of b0 volumes"), #Item('max_b0_val', label="Maximumb b value"), Item('hardi_recon_param', label="odf_recon Parameters"), Item('dtb_dtk2dir_param', label="DTB_dtk2dir Parameters"), show_border=True, visible_when="diffusion_imaging_model == 'QBALL'"), visible_when="active_reconstruction", label="Reconstruction", ) segementation_group = Group( VGroup( Item('recon_all_param', label="recon_all Parameters"), show_border=True, ), enabled_when="active_segmentation", visible_when="active_segmentation", label="Segmentation", ) tractography_group = Group( VGroup( Item('streamline_param', label="DTB_streamline Parameters"), show_border=True, visible_when= "diffusion_imaging_model == 'DSI' or diffusion_imaging_model == 'QBALL'", ), VGroup(Item('streamline_param_dti', label="dti_tracker Parameters"), show_border=True, visible_when="diffusion_imaging_model == 'DTI'"), enabled_when="active_tractography", visible_when="active_tractography", label="Tractography", ) fiberfilter_group = Group( VGroup(Item('apply_splinefilter', label="Apply spline filter"), Item('apply_fiberlength', label="Apply cutoff filter"), Item('fiber_cutoff_lower', label='Lower cutoff length (mm)', enabled_when='apply_fiberlength'), Item('fiber_cutoff_upper', label='Upper cutoff length (mm)', enabled_when='apply_fiberlength'), show_border=True, enabled_when="active_fiberfilter"), visible_when="active_fiberfilter", label="Fiber Filtering", ) connectioncreation_group = Group( VGroup(Item('compute_curvature', label="Compute curvature"), show_border=True, enabled_when="active_connectome"), visible_when="active_connectome", label="Connectome Creation", ) cffconverter_group = 
Group( VGroup( Item('cff_fullnetworkpickle', label="All connectomes"), # Item('cff_cmatpickle', label='cmat.pickle'), Item('cff_originalfibers', label="Original Tractography"), Item('cff_filteredfibers', label="Filtered Tractography"), Item('cff_fiberarr', label="Filtered fiber arrays"), Item('cff_finalfiberlabels', label="Final Tractography and Labels"), Item('cff_scalars', label="Scalar maps"), Item('cff_rawdiffusion', label="Raw Diffusion Data"), Item('cff_rawT1', label="Raw T1 data"), Item('cff_rawT2', label="Raw T2 data"), Item('cff_roisegmentation', label="Parcellation Volumes"), Item('cff_surfaces', label="Surfaces", tooltip='stores individually generated surfaces'), #Item('cff_surfacelabels', label="Surface labels", tooltip = 'stores the labels on the surfaces'), show_border=True, ), visible_when="active_cffconverter", label="CFF Converter", ) configuration_group = Group( VGroup( Item('emailnotify', label='E-Mail Notification'), #Item('wm_handling', label='White Matter Mask Handling', tooltip = """1: run through the freesurfer step without stopping #2: prepare whitematter mask for correction (store it in subject dir/NIFTI #3: rerun freesurfer part with corrected white matter mask"""), Item('freesurfer_home', label="Freesurfer Home"), Item('fsl_home', label="FSL Home"), Item('dtk_home', label="DTK Home"), show_border=True, ), label="Configuration", ) view = View( Group( HGroup(main_group, metadata_group, subject_group, dicomconverter_group, registration_group, segementation_group, parcellation_group, reconstruction_group, tractography_group, fiberfilter_group, connectioncreation_group, cffconverter_group, configuration_group, orientation='horizontal', layout='tabbed', springy=True), spring, HGroup( Item('about', label='About', show_label=False), Item('help', label='Help', show_label=False), Item('save', label='Save State', show_label=False), Item('load', label='Load State', show_label=False), spring, Item('run', label='Map Connectome!', show_label=False), ), 
), resizable=True, width=0.3, handler=CMPGUIHandler, title='Connectome Mapper', ) def _about_fired(self): a = HelpDialog() a.configure_traits(kind='livemodal') def _help_fired(self): a = HelpDialog() a.configure_traits(kind='livemodal') def load_state(self, cmpconfigfile): """ Load CMP Configuration state directly. Useful if you do not want to invoke the GUI""" import enthought.sweet_pickle as sp output = open(cmpconfigfile, 'rb') data = sp.load(output) self.__setstate__(data.__getstate__()) # make sure that dtk_matrices is set self.dtk_matrices = os.path.join(self.dtk_home, 'matrices') # update the subject directory if os.path.exists(self.project_dir): self.subject_workingdir = os.path.join(self.project_dir, self.subject_name, self.subject_timepoint) output.close() def save_state(self, cmpconfigfile): """ Save CMP Configuration state directly. Useful if you do not want to invoke the GUI Parameters ---------- cmpconfigfile : string Absolute path and filename to store the CMP configuration pickled object """ # check if path available if not os.path.exists(os.path.dirname(cmpconfigfile)): os.makedirs(os.path.abspath(os.path.dirname(cmpconfigfile))) import enthought.sweet_pickle as sp output = open(cmpconfigfile, 'wb') # Pickle the list using the highest protocol available. # copy object first tmpconf = CMPGUI() tmpconf.copy_traits(self) sp.dump(tmpconf, output, -1) output.close() def show(self): """ Show the GUI """ #self.configure_traits() self.edit_traits(kind='livemodal') # def _gradient_table_file_default(self): # return self.get_gradient_table_file() # XXX this is not automatically invoked! def _get_gradient_table_file(self): if self.gradient_table == 'custom': gradfile = self.get_custom_gradient_table() else: gradfile = self.get_cmp_gradient_table(self.gradient_table) if not os.path.exists(gradfile): msg = 'Selected gradient table %s does not exist!' 
% gradfile raise Exception(msg) return gradfile def _project_dir_changed(self, value): self.subject_workingdir = value def _subject_name_changed(self, value): self.subject_workingdir = os.path.join(self.project_dir, value, self.subject_timepoint) def _subject_timepoint_changed(self, value): self.subject_workingdir = os.path.join(self.project_dir, self.subject_name, value) def _gradient_table_changed(self, value): if value == 'custom': self.gradient_table_file = self.get_custom_gradient_table() else: self.gradient_table_file = self.get_cmp_gradient_table(value) if not os.path.exists(self.gradient_table_file): msg = 'Selected gradient table %s does not exist!' % self.gradient_table_file raise Exception(msg) def _parcellation_scheme_changed(self, value): if value == "Lausanne2008": self.parcellation = self._get_lausanne_parcellation( parcel="Lausanne2008") else: self.parcellation = self._get_lausanne_parcellation( parcel="NativeFreesurfer") def _inspect_registration_fired(self): cmp.registration.inspect(self) def _inspect_tractography_fired(self): cmp.tractography.inspect(self) def _inspect_tractography_filtered_fired(self): cmp.fiberfilter.inspect(self) def _inspect_segmentation_fired(self): cmp.freesurfer.inspect(self) def _inspect_parcellation_fired(self): cmp.maskcreation.inspect(self) def _inspect_connectionmatrix_fired(self): cmp.connectionmatrix.inspect(self) def _run_fired(self): pass # execute the pipeline thread # first do a consistency check #self.consistency_check() # otherwise store the pickle #self.save_state(os.path.join(self.get_log(), self.get_logname(suffix = '.pkl')) ) # hide the gui # run the pipeline #print "mapit" #cmp.connectome.mapit(self) # show the gui #cmpthread = CMPThread(self) #cmpthread.start() def _load_fired(self): import enthought.sweet_pickle as sp from enthought.pyface.api import FileDialog, OK wildcard = "CMP Configuration State (*.pkl)|*.pkl|" \ "All files (*.*)|*.*" dlg = FileDialog(wildcard=wildcard,title="Select a configuration 
state to load",\ resizeable=False, \ default_directory=self.project_dir,) if dlg.open() == OK: if not os.path.isfile(dlg.path): return else: self.load_state(dlg.path) def _save_fired(self): import pickle import enthought.sweet_pickle as sp import os.path from enthought.pyface.api import FileDialog, OK wildcard = "CMP Configuration State (*.pkl)|*.pkl|" \ "All files (*.*)|*.*" dlg = FileDialog(wildcard=wildcard,title="Filename to store configuration state",\ resizeable=False, action = 'save as', \ default_directory=self.subject_workingdir,) if dlg.open() == OK: if not dlg.path.endswith('.pkl'): dlg.path = dlg.path + '.pkl' self.save_state(dlg.path)
class FitProcessor(HasTraits):
    """ A traits based class for simplifying multidimensional nonlinear least
    squares function fitting.

    Assign `fit_data` and `fit_model`, then call `fit()`.  The error function
    is built by `update_error_func()`; `fit()` builds it on demand when it has
    not been built yet.  A YAML-like report of the last fit is kept in
    `fit_log`.
    """
    fit_data = Instance(FitData)  # holds the data
    fit_model = Instance(FitModel)  # holds the data selection, fit parameters, and functional model
    error_func = Function  # automatically generated error function
    optimizer = Instance(NLSOptimizer)  # optimizer for error function, obtains best fit parameters
    iter_num = Int(0)  # keep track of the optimizer iterations
    fit_log = Str("")  # stores the info from fitting in a YAML format

    view = View(Item('optimizer', label='Optimizer', style='custom'),
                resizable=True,
                height=0.75,
                width=0.25)

    #--------------------------------------------------------------------------
    #@on_trait_change('fit_model')
    def update_error_func(self):
        """Build `error_func` for the current free parameters and data selection.

        The free parameter names, the parameter dict and the selected
        (X, Y, W) data are captured once here; later changes to `fit_data`
        or `fit_model` are only picked up by calling this method again.
        """
        # freeze out copies of data and parameters
        fp_names = self.fit_model.get_free_param_names()
        pdict = self.fit_model.get_params_dict()
        X, Y, W = self.fit_data.get_selection()
        func = self.fit_model.evaluate

        # create the function closure on the free (varied) parameter set
        def varied_func(p):
            pdict.update(dict(zip(fp_names, p)))
            return func(X=X, pdict=pdict)

        # create the error function closure
        def error_func(p):
            # evaluate the varied function on the parameter set, p
            F = varied_func(p)
            # pair each data row with its function evaluation element and weighting
            errs = [(y - f) * w for y, f, w in zip(Y, F, W)]
            # create a lumped row vector of all the deviations
            return hstack(errs)

        self.error_func = error_func

    #--------------------------------------------------------------------------
    def fit(self):
        """Run the optimizer on the error function to obtain best fit parameters.

        On success the free parameters of `fit_model` are updated with the
        optimized values and, when the optimizer supplies a covariance matrix,
        with their errors.  Progress and results are written to `fit_log`.
        Nothing is fitted when the model has no free parameters.
        """
        P0 = self.fit_model.get_free_param_values()
        self._clear_log()  # empty the fitting log
        self._print_log("## Starting Fit ##")
        if len(P0) == 0:
            # empty free parameter set, do not fit
            return

        if not callable(self.error_func):
            # the automatic rebuild on 'fit_model' changes is disabled above,
            # so build the error function now rather than optimizing nothing
            self.update_error_func()
        error_func = self.error_func

        self.iter_num = 0  # reset iteration counter
        self.optimizer = NLSOptimizer(cost_map=error_func, P0=P0)
        self.optimizer.optimize()

        # determine if fitting was successful
        success = self.optimizer.success
        msg = self.optimizer.message
        if success:
            # update the parameters
            fp_values = self.optimizer.P
            fp_names = self.fit_model.get_free_param_names()
            for name, value in zip(fp_names, fp_values):
                self.fit_model.update_param(name, value=value)
            # compute the error on the parameters, if possible
            err = self.optimizer.cost
            ndf = self.optimizer.ndf
            reduced_chisqr = (err * err).sum() / ndf
            covar = self.optimizer.covar
            if covar is not Undefined:
                # rescale the covariance matrix (in place, as before: this also
                # updates self.optimizer.covar)
                covar *= reduced_chisqr
                p_err = sqrt(covar.diagonal())
                for name, error in zip(fp_names, p_err):
                    self.fit_model.update_param(name, error=error)
            self._print_log("## Fitting Completed ##")
            self._print_log("---")
            self._print_log("parameters:")
            self._print_log(self.fit_model.params, level=2, indent='  ')
            self._print_log("ndf: %d" % ndf)
            self._print_log("reduced_chisqr: %g" % reduced_chisqr)
        else:
            self._print_log("## Fitting Failed! ##")
            self._print_log("---")
            self._print_log("ierr: %s" % self.optimizer.ier)
            self._print_log("message: %s" % msg)
        # YAML end-of-document marker closing the "---" started above
        self._print_log("...")

    def _clear_log(self):
        """Empty `fit_log`."""
        self.fit_log = ""

    def _print_log(self, text, indent="\t", level=0, newline='\n'):
        """Append ``str(text)`` plus ``newline`` to `fit_log`.

        When ``level >= 1`` every line of the text is prefixed with
        ``indent * level``.
        """
        text = str(text)
        if level >= 1:
            # reformat the text to indent it
            space = indent * level
            text = newline.join(space + line for line in text.split(newline))
        self.fit_log += text + newline
class CheckListTest(Handler): #--------------------------------------------------------------------------- # Trait definitions: #--------------------------------------------------------------------------- case = Enum('Colors', 'Numbers') value = List(editor=CheckListEditor(values=colors, cols=5)) #--------------------------------------------------------------------------- # Event handlers: #--------------------------------------------------------------------------- def object_case_changed(self, info): if self.case == 'Colors': info.value.factory.values = colors else: info.value.factory.values = numbers #------------------------------------------------------------------------------- # Run the tests: #------------------------------------------------------------------------------- if __name__ == '__main__': clt = CheckListTest() clt.configure_traits(view=View('case', '_', Item('value', id='value'))) print 'value:', clt.value clt.configure_traits(view=View('case', '_', Item('value@', id='value'))) print 'value:', clt.value
class Figure(HasTraits):
    """ Traits wrapper around a matplotlib figure that displays an image and
    reports rectangle selections made on it through `process_selection`.
    """
    figure = Instance(MPL_Figure, transient=True)
    # Called as process_selection(pos1, pos2) with the (xdata, ydata) of the
    # press and release events of a rectangle selection; no-op by default.
    process_selection = Function
    # Holds the active RectangleSelector: matplotlib widgets stop responding
    # once nothing references them.  Transient so it is never pickled.
    _selector = Instance(RectangleSelector, transient=True)

    view = View(Item('figure', editor=MPLFigureEditor(), width=400,
                     show_label=False, height=300),
                resizable=True)

    def _process_selection_default(self):
        def f(point0, point1):
            pass
        return f

    def _figure_default(self):
        """Create the figure, seeded with a blank 300x400 uint8 image."""
        figure = MPL_Figure()
        # assign first: update_image() reads self.figure
        self.figure = figure
        image = zeros(shape=(300, 400), dtype='uint8')
        self.update_image(image)
        return figure

    def append_selector(self):
        """Attach a box RectangleSelector forwarding selections to process_selection."""
        def line_select_callback(event1, event2):
            'event1 and event2 are the press and release events'
            pos1 = event1.xdata, event1.ydata
            pos2 = event2.xdata, event2.ydata
            self.process_selection(pos1, pos2)

        ax = self.figure.add_subplot(111)
        # keep a reference, otherwise the selector can be garbage collected
        self._selector = RectangleSelector(ax, line_select_callback,
                                           drawtype='box', useblit=True,
                                           minspanx=0, minspany=0,
                                           spancoords='pixels')

    def update_image(self, data):
        """Replace the displayed image(s) with `data` and redraw the canvas."""
        ax = self.figure.add_subplot(111)
        ax.set_autoscale_on(True)
        ax.images = []
        try:
            ax.imshow(data, interpolation='nearest')
            self.append_selector()
        except Exception:
            # best effort: an image that cannot be shown leaves the axes empty
            pass
        finally:
            ax.set_autoscale_on(False)
            self.figure.canvas.draw()

    def plot_data(self, x, y, name='data 0', color='black'):
        """Plot y against x and label the first point with `name`."""
        ax = self.figure.add_subplot(111)
        ax.plot(x, y, color)
        ax.text(x[0], y[0], name, color=color)
        self.figure.canvas.draw()

    def del_plot(self, name):
        """Remove plotted lines and texts; only name == 'all' is supported."""
        if name == 'all':
            ax = self.figure.add_subplot(111)
            ax.lines = []
            ax.texts = []
            # refresh the display, as plot_data() does after adding
            self.figure.canvas.draw()
class CheckListEditorDemo(HasTraits):
    """ Demonstrates the CheckListEditor laid out in four, two and one
    columns, each shown in every editor style.
    """

    def _checklist(cols):
        # A List trait edited by a CheckListEditor over a fresh copy of the
        # four candidate values, laid out in `cols` columns.
        return List(editor=CheckListEditor(
            values=['one', 'two', 'three', 'four'], cols=cols))

    def _styled_group(trait_name, tab_label):
        # Simple / custom / text / readonly renderings separated by rules.
        return Group(Item(trait_name, style='simple', label='Simple'),
                     Item('_'),
                     Item(trait_name, style='custom', label='Custom'),
                     Item('_'),
                     Item(trait_name, style='text', label='Text'),
                     Item('_'),
                     Item(trait_name, style='readonly', label='ReadOnly'),
                     label=tab_label)

    # One trait per column formation:
    checklist_4col = _checklist(4)
    checklist_2col = _checklist(2)
    checklist_1col = _checklist(1)

    # One tab per column formation:
    cl_4_group = _styled_group('checklist_4col', '4-column')
    cl_2_group = _styled_group('checklist_2col', '2-column')
    cl_1_group = _styled_group('checklist_1col', '1-column')

    # The groups are shown on separate tabbed panels.
    view1 = View(cl_4_group, cl_2_group, cl_1_group,
                 title='CheckListEditor',
                 buttons=['OK'],
                 resizable=True)

    # the factories were only needed while building the class body
    del _checklist, _styled_group
class BasePlot(HasTraits):
    """ An interface defining an object which can render a plot on a figure object
    """
    implements(IPlot)
    # expected number of variables in each X / Y data set (see _convert_data)
    n_x = Int(1)
    n_y = Int(1)
    figure = Instance(Figure, ())
    view = View(
        Item('figure',
             #height = 600,
             #width = 800,
             style='custom',
             show_label=False,
             # this editor will automatically find and connect the _handle_onpick
             # method for handling matplotlib's object picking events
             editor=MPLFigureEditor(),
             ),
    )

    def clear(self):
        """Remove everything drawn on the figure."""
        self.figure.clear()

    def render(self, Xs, Ys, fmts=None, labels=None, pickable=(), **kwargs):
        ''' Plots data from 'Xs', 'Ys' on 'figure' and returns the figure object

        Xs, Ys   -- data sets, normalised by _convert_data
        fmts     -- optional matplotlib format strings, one per data set
        labels   -- optional legend labels, one per data set; a legend is
                    drawn when any are given
        pickable -- indices into the axes' lines that get a 5.0 pick radius
        kwargs   -- forwarded to axes.plot
        '''
        try:
            from itertools import zip_longest
        except ImportError:  # Python 2
            from itertools import izip_longest as zip_longest

        data = self._convert_data(Xs, Ys)
        Xs = data['Xs']
        Ys = data['Ys']
        if fmts is None:
            fmts = []
        if labels is None:
            labels = []
        axes = self.figure.add_subplot(111)
        # pad the shorter sequences with None (what py2's map(None, ...) did)
        for X, Y, fmt, label in zip_longest(Xs, Ys, fmts, labels):
            if X is None or Y is None:
                continue
            kwargs['label'] = label
            self._plot(X, Y, fmt, axes=axes, **kwargs)
        if labels:
            axes.legend()
        # set up the plot point object selection
        for ind in pickable:
            axes.lines[ind].set_picker(5.0)
        return self.figure

    def redraw(self):
        """Redraw the canvas if the figure is attached to one."""
        if self.figure.canvas is not None:
            self.figure.canvas.draw()

    def register_onpick_handler(self, handler):
        """Install `handler` as this object's _handle_onpick callback."""
        self._handle_onpick = handler

    def _plot(self, x, y, fmt=None, axes=None, **kwargs):
        """Plot y against x on `axes`, with `fmt` when one is given."""
        if axes is None:
            raise TypeError("an 'axes' object must be supplied")
        if fmt is None:
            axes.plot(x, y, **kwargs)
        else:
            axes.plot(x, y, fmt, **kwargs)

    def _convert_data(self, Xs=None, Ys=None):
        """Normalise Xs and Ys into numpy arrays shaped for plotting.

        Each array is checked against its expected variable count n
        (self.n_x for Xs, self.n_y for Ys):
          * 1D           -> (1, N) when n == 1, otherwise TypeError
          * 2D (d1, d2)  -> unchanged when n == 1; (1, d1, d2) when d1 == n;
                            otherwise TypeError
          * 3D (d1, d2, d3) -> (d1, d3) when n == 1 and d2 == 1; unchanged
                            when d2 == n; otherwise TypeError
          * other ranks  -> TypeError
        None becomes an empty array.  Returns a dict with keys 'Xs' and 'Ys'.
        """
        data_args = {'Xs': (Xs, self.n_x), 'Ys': (Ys, self.n_y)}
        data = {}
        for name, (D, n) in data_args.items():
            if D is None:
                # default to an empty array
                data[name] = array([])
                continue
            D = array(D)  # convert to a numpy array (lists accepted)
            dim = len(D.shape)
            if dim == 1:
                if n == 1:  # upconvert 1D array to 2D
                    D = D.reshape((1, -1))
                else:
                    raise TypeError(
                        "'%s' dimension must be 2 or 3 for n > 1, detected incommensurate data of dimension %d"
                        % (name, dim))
            elif dim == 2:
                d1, d2 = D.shape
                if n == 1:
                    pass  # no conversion needed
                elif d1 != n:
                    raise TypeError("'%s' shape (%d,%d) must match (n=%d,:)"
                                    % (name, d1, d2, n))
                else:  # up convert 2D array to 3D
                    D = D.reshape((1, d1, d2))
            elif dim == 3:
                d1, d2, d3 = D.shape
                if n == 1 and d2 == 1:  # down convert 3D array to 2D
                    D = D.reshape((d1, d3))
                elif d2 != n:
                    raise TypeError("'%s' shape (%d,%d,%d) must match (:,n=%d,:)"
                                    % (name, d1, d2, d3, n))
            else:
                raise TypeError(
                    "'%s' dimension must be 1, 2 or 3, detected incommensurate data of dimension %d"
                    % (name, dim))
            data[name] = D
        return data
from enthought.traits.api import HasTraits, Str, Int, Bool
from enthought.traits.ui.api import View, Group, Item

#--[Code]-----------------------------------------------------------------------

# Sample class
class House(HasTraits):
    # A simple house record with four editable fields.
    address = Str
    bedrooms = Int
    pool = Bool
    price = Int


# View object designed to display two objects of class 'House'
# side by side (orientation='horizontal'), one bordered group per house.
# NOTE(review): the 'h1.' / 'h2.' prefixes presumably name objects supplied
# in the view's context under the keys 'h1' and 'h2' -- confirm at the call site.
comp_view = View(Group(Group(Item('h1.address', resizable=True), Item('h1.bedrooms'), Item('h1.pool'), Item('h1.price'), show_border=True), Group(Item('h2.address', resizable=True), Item('h2.bedrooms'), Item('h2.pool'), Item('h2.price'), show_border=True), orientation='horizontal'), title='House Comparison')

# A pair of houses to demonstrate the View
house1 = House(address='4743 Dudley Lane', bedrooms=3,
class ScatterPlotNM(MutableTemplate):
    """ A rows x columns grid of ScatterPlot templates sharing one plot index
    and common marker / colour options.
    """

    #-- Template Traits --------------------------------------------------------

    # The title of the plot:
    title = TStr('NxM Scatter Plots')

    # The type of marker to use. This is a mapped trait using strings as the
    # keys:
    marker = marker_trait(template='copy', event='update')

    # The pixel size of the marker (doesn't include the thickness of the
    # outline):
    marker_size = TRange(1, 5, 1, event='update')

    # The thickness, in pixels, of the outline to draw around the marker. If
    # this is 0, no outline will be drawn.  Like the other plot options it
    # fires 'update' so the change is pushed to every contained scatter plot.
    line_width = TRange(0.0, 5.0, 1.0, event='update')

    # The fill color of the marker:
    color = TColor('red', event='update')

    # The color of the outline to draw around the marker
    outline_color = TColor('black', event='update')

    # The number of rows of plots:
    rows = TRange(1, 3, 1, event='grid')

    # The number of columns of plots:
    columns = TRange(1, 5, 1, event='grid')

    # The contained scatter plots:
    scatter_plots = TList(ScatterPlot)

    #-- Derived Traits ---------------------------------------------------------

    plot = TDerived

    #-- Traits UI Views --------------------------------------------------------

    # The scatter plot view:
    template_view = View(
        VGroup(
            Item('title',
                 show_label=False,
                 style='readonly',
                 editor=ThemedTextEditor(
                     theme=Theme('@GBB', alignment='center'))),
            Item('plot',
                 show_label=False,
                 resizable=True,
                 editor=EnableEditor(),
                 item_theme=Theme('@GF5', margins=0))),
        resizable=True)

    # The scatter plot options view:
    options_view = View(
        VGroup(
            VGroup(Label('Scatter Plot Options',
                         item_theme=Theme('@GBB', alignment='center')),
                   show_labels=False),
            VGroup(Item('title', editor=TextEditor()),
                   Item('marker'),
                   Item('marker_size', editor=ThemedSliderEditor()),
                   Item('line_width',
                        label='Line Width',
                        editor=ThemedSliderEditor()),
                   Item('color', label='Fill Color'),
                   Item('outline_color', label='Outline Color'),
                   Item('rows', editor=ThemedSliderEditor()),
                   Item('columns', editor=ThemedSliderEditor()),
                   group_theme=Theme('@GF5', margins=(-5, -1)),
                   item_theme=Theme('@G0B', margins=0))))

    #-- ITemplate Interface Implementation -------------------------------------

    def activate_template(self):
        """ Converts all contained 'TDerived' objects to real objects using the
            template traits of the object.

            Builds a GridPlotContainer of shape (rows, columns) from the
            contained scatter plots, using an empty PlotComponent for any
            plot that has not been produced.

            Returns
            -------
            None
        """
        plots = []
        i = 0
        for r in range(self.rows):
            row = []
            for c in range(self.columns):
                plot = self.scatter_plots[i].plot
                if plot is None:
                    plot = PlotComponent()
                row.append(plot)
                i += 1
            plots.append(row)

        self.plot = GridPlotContainer(shape=(self.rows, self.columns))
        self.plot.component_grid = plots

    #-- Default Values ---------------------------------------------------------

    def _scatter_plots_default(self):
        """ Returns the default value for the scatter plots list.
        """
        plots = [ScatterPlot() for i in range(self.rows * self.columns)]
        self._update_plots(plots)
        return plots

    #-- Trait Event Handlers ---------------------------------------------------

    def _update_changed(self, name, old, new):
        """ Handles a plot option being changed: copies the new value to every
            contained scatter plot and invalidates the derived plot.
        """
        for sp in self.scatter_plots:
            setattr(sp, name, new)
        self.plot = Undefined

    def _grid_changed(self):
        """ Handles the grid size being changed: trims or extends the list of
            scatter plots to rows * columns and relabels them.
        """
        n = self.rows * self.columns
        plots = self.scatter_plots
        if n < len(plots):
            self.scatter_plots = plots[:n]
            plots = self.scatter_plots
        else:
            for _ in range(len(plots), n):
                plots.append(ScatterPlot())
        # relabel in every case so each plot's [r,c] matches the new grid
        self._update_plots(plots)
        self.template_mutated = True

    #-- Private Methods --------------------------------------------------------

    def _update_plots(self, plots):
        """ Update the data sources for all of the current plots.

            Each plot's value description gets a '[row,column]' suffix, and
            all plots share the index of the first one.
        """
        index = None
        i = 0
        for r in range(self.rows):
            for c in range(self.columns):
                sp = plots[i]
                i += 1
                desc = sp.value.description
                col = desc.rfind('[')
                if col >= 0:
                    desc = desc[:col]
                sp.value.description = '%s[%d,%d]' % (desc, r, c)
                sp.value.optional = True

                if index is None:
                    index = sp.index
                    index.description = 'Shared Plot Index'
                    index.optional = True
                else:
                    sp.index = index
def get_item(self):
    """Build a traits UI Item for this object's name, style and visibility
    rule that shows the value as a two-decimal number followed by ' cm'."""
    def as_cm(v):
        return '%.2f cm' % v

    return Item(
        name=self.name,
        style=self.style,
        visible_when=self.visible_when,
        format_func=as_cm,
    )
class Graph(HasTraits):
    """
    Plotting component: data-selection controls on the left, the plot on the right.
    """
    name = Str  # plot name, shown as the tab title and the plot title
    data_source = Instance(DataSource)  # data source holding the data
    figure = Instance(Figure)  # Figure object drawn by the plot control
    selected_xaxis = Str  # name of the data used for the X axis
    selected_items = List  # list of data names used for the Y axis

    clear_button = Button(u"清除")  # quickly clears all selected Y-axis data

    view = View(
        HSplit(  # left and right panes with a draggable divider between them
            # left pane: a group of selection controls
            VGroup(
                Item("name"),  # plot name text box
                Item("clear_button"),  # clear button
                Heading(u"X轴数据"),  # static heading text
                # X-axis selector: an EnumEditor (ComboBox) whose choices come
                # from data_source.names
                Item("selected_xaxis",
                     editor=EnumEditor(name="object.data_source.names",
                                       format_str=u"%s")),
                Heading(u"Y轴数据"),  # static heading text
                # Y-axis selector: several fields may be chosen, so use a
                # two-column checkbox list
                Item("selected_items",
                     style="custom",
                     editor=CheckListEditor(name="object.data_source.names",
                                            cols=2,
                                            format_str=u"%s")),
                show_border=True,  # draw a border around the group
                scrollable=True,  # scroll when the controls do not fit
                show_labels=False  # hide the labels of all items in the group
            ),
            # right pane: the plot control
            Item("figure", editor=MPLFigureEditor(), show_label=False, width=600)
        )
    )

    def _redraw(self):
        """Redraw the canvas, if the figure is already attached to one."""
        canvas = self.figure.canvas
        if canvas is not None:
            canvas.draw()

    def _name_changed(self):
        """When the plot name changes, update the plot title."""
        axe = self.figure.axes[0]
        axe.set_title(self.name)
        self._redraw()

    def _clear_button_fired(self):
        """Clear the Y-axis selection and redraw."""
        self.selected_items = []
        # explicit redraw: assigning an already-empty list may not fire
        # _selected_items_changed
        self.update()

    def _figure_default(self):
        """Default value of figure: a new Figure with one axes and margins."""
        figure = Figure()
        figure.add_axes([0.1, 0.1, 0.85, 0.80])  # plotting area with margins
        return figure

    def _selected_items_changed(self):
        """Y-axis selection changed."""
        self.update()

    def _selected_xaxis_changed(self):
        """X-axis selection changed."""
        self.update()

    def update(self):
        """Redraw all curves: one line per selected Y field against the X field."""
        axe = self.figure.axes[0]
        axe.clear()
        try:
            xdata = self.data_source.data[self.selected_xaxis]
        except Exception:
            # no data source or no valid X field: show the cleared axes
            # instead of leaving stale curves on screen
            self._redraw()
            return
        for field in self.selected_items:
            axe.plot(xdata, self.data_source.data[field], label=field)
        axe.set_xlabel(self.selected_xaxis)
        axe.set_title(self.name)
        if self.selected_items:
            # legend() with no labelled lines only produces a warning
            axe.legend()
        self._redraw()
except: return 1. return cmp(getNum(x), getNum(y)) filenames_view = View( Group( #Item('directory', style = 'simple'), #Item('pattern', style = 'simple'), Item('filenames', style='custom', editor=ListStrEditor( selected='selected', operations=['insert', 'edit', 'move', 'delete', 'append'], auto_add=True, drag_move=True), height=-100, width=-300), Item('is_reversed', style='simple'), ), Item('from_directory_bttn', show_label=False), # statusbar = [ StatusItem( name = 'error')], resizable=True, ) def filenames_from_list(filenames): """A helper function. Returns a :class:`Filenames` object from a given filenames list
class FieldExplorer(HasTraits):
    """ Shows the field computed by calc_wire_B_field for a WireLoop, sampled
    on the points of a movable plane widget and drawn as cone glyphs.
    """
    scene = Instance(SceneModel, ())
    wire = Instance(WireLoop)
    interact = Bool(False)  # toggles the plane widget on / off
    ipl = Instance(tvtk.PlaneWidget, (), {
        'resolution': 50,
        'normal': [1., 0., 0.]
    })
    #plane_src = Instance(tvtk.PlaneSource, ())
    calc_B = Instance(tvtk.ProgrammableFilter, ())
    glyph = Instance(tvtk.Glyph3D, (), {'scale_factor': 0.02})
    scale_factor = DelegatesTo("glyph")
    lm = Instance(LUTManager, ())

    traits_view = View(HSplit(
        Item("scene", style="custom", editor=SceneEditor(), show_label=False),
        VGroup(Item("wire", style="custom", show_label=False),
               Item("interact"),
               Item("scale_factor"),
               Item("lm")),
    ), resizable=True, width=700, height=600)

    def _interact_changed(self, i):
        """Attach the plane widget to the scene and switch it on or off."""
        self.ipl.interactor = self.scene.interactor
        self.ipl.place_widget()
        if i:
            self.ipl.on()
        else:
            self.ipl.off()

    def make_probe(self):
        """Build plane -> field filter -> cone glyphs -> mapper and add the actor."""
        src = self.ipl.poly_data_algorithm
        mapper = tvtk.PolyDataMapper(lookup_table=self.lm.lut)
        act = tvtk.Actor(mapper=mapper)

        calc_B = self.calc_B
        calc_B.input = src.output

        def execute():
            # compute the field vectors at the plane's points and scale the
            # colour map to their magnitude range
            output = calc_B.poly_data_output
            points = output.points.to_array().astype('d')
            nodes = self.wire.nodes.astype('d')
            vectors = calc_wire_B_field(nodes, points, self.wire.radius)
            output.point_data.vectors = vectors
            mag = np.sqrt((vectors**2).sum(axis=1))
            mapper.scalar_range = (mag.min(), mag.max())

        calc_B.set_execute_method(execute)

        cone = tvtk.ConeSource(height=0.05, radius=0.01, resolution=15)
        cone.update()

        glyph = self.glyph
        glyph.input_connection = calc_B.output_port
        glyph.source = cone.output
        glyph.scale_mode = 'scale_by_vector'
        glyph.color_mode = 'color_by_vector'

        mapper.input_connection = glyph.output_port
        self.scene.add_actor(act)

    def on_update(self):
        """Recompute the field and re-render the scene."""
        self.calc_B.modified()
        self.scene.render()

    def _wire_changed(self, old, new):
        """Detach the previous wire (listener and actor) and attach the new one."""
        if old is not None:
            old.on_trait_change(self.on_update, "update", remove=True)
            self.scene.remove_actor(old.actor)
        if new is not None:
            new.on_trait_change(self.on_update, "update")
            self.scene.add_actor(new.actor)
class DlsAnalyzer(BaseFileAnalyzer):
    """
    DlsAnalyzer is used to analyze multiple dls files.

    First you must define a function that returns the x value for the data
    analyzed. The default function :attr:`get_x_value` just returns the index
    value. This function must take two arguments as input: the list of
    filenames and the index value. It is up to you how the return value uses
    these inputs. For instance:

    >>> def get_x(fnames, index):
    ...     return 100 + 0.1 * index

    Then create a :class:`Filenames` instance (optional)

    >>> filenames = Filenames(directory = '../testdata', pattern = '*.ASC')

    Now you can create the analyzer and do some analysis

    >>> fitter = create_dls_fitter('single_stretch_exp')
    >>> analyzer = DlsAnalyzer(filenames = filenames,
    ...                        fitter = fitter,
    ...                        get_x_value = get_x)
    >>> analyzer.log_name = 'analysis.rst' #specify logname to log results in reStructuredText format
    >>> analyzer.constants = (('s','n'),()) #set constant parameters in fitting process,
    >>> analyzer.x_name = 'position' #specify x data name

    When everything is set you can call process to fit all data.

    >>> analyzer.process()
    >>> analyzer.save_results('../testdata/output.npy')
    """
    #: Filenames instance
    filenames = Instance(Filenames, ())
    #: selected filename
    selected = DelegatesTo('filenames')
    #: data fitter object for data fitting
    fitter = Instance(DlsFitter)
    #: defines a list of constants tuple that are set in each fit run. See :meth:`process`
    constants = List(List(Str))
    #: defines whether fit plots are saved
    saves_fits = Bool(False)
    #: if defined it will generate a valid reStructuredText file
    log_name = Str
    #: actual log is written here
    log = Str
    #: fit results are stored here
    results = Instance(StructArrayData, ())
    #: This function is used to get the x value from the filenames list and the index integer
    get_x_value = Function
    #: this specifies name of the x data of results
    x_name = Str('index')
    #: if this list is not empty it will be used to obtain x_values
    x_values = List(Float)

    view = View(Group(dls_analyzer_group, 'saves_fits', 'results'),
                Item('fitter', style='custom'),
                resizable=True)

    @on_trait_change('selected')
    def _open_dls(self, name):
        """Open the newly selected file in the fitter and plot it."""
        self.fitter.open_dls(name)
        self.fitter._plot()

    def _constants_default(self):
        return [['f', 's'], ['']]

    def _get_x_value_default(self):
        def get(fnames, index):
            return index
        return get

    def _selected_changed(self):
        self.process_selected()

    def process_selected(self):
        """Opens the selected file and fits data according to self.constants.

        One fit is run per entry of ``self.constants``; if a fit raises, the
        fitter's traits dialog is opened.  When :attr:`saves_fits` is set the
        fit plot is saved as ``<dir>/fits/<name>.png``.

        :returns: the fitted parameters from ``fitter.function.get_parameters()``
        """
        fname = self.selected
        self.fitter.open_dls(fname)
        print(self.constants)
        for constants in self.constants:
            try:
                self.fitter.fit(constants=constants)
            except Exception:
                self.fitter.configure_traits()
        if self.saves_fits:
            path, fname = os.path.split(fname)
            path = os.path.join(path, 'fits')
            try:
                os.mkdir(path)
            except OSError:
                pass  # most likely the directory exists already
            fname = os.path.join(path, fname)
            imagename = fname + '.png'
            log.info('Plotting %s' % imagename)
            self.fitter.plotter.title = imagename
            self.fitter.plotter.savefig(imagename)
        result = self.fitter.function.get_parameters()
        # NOTE(review): `index` is not defined in this class; presumably it
        # comes from BaseFileAnalyzer -- confirm.
        self._process_result(result, self.selected, self.index)
        return result

    def _process_result(self, result, fname, index):
        """Store the flattened fit ``result`` as row ``index`` of ``self.results``.

        The x value is ``self.x_values[index]`` when available, otherwise
        ``self.get_x_value(self.filenames.filenames, index)``.
        """
        # flatten results list first -- into a tuple, not a generator, so it
        # cannot be consumed by a failed first attempt
        flat = tuple(i for sub in result for i in sub)
        try:
            x = self.x_values[index]
        except (IndexError, TypeError):
            x = self.get_x_value(self.filenames.filenames, index)
        self.results.data[index] = (x, ) + flat
        self.results.data_updated = True

    @on_trait_change('filenames.filenames')
    def _init(self):
        """Allocate an empty results array: the x column plus a value and an
        ``_err`` column for every fit parameter, one row per file."""
        array_names = [self.x_name]
        for name in self.fitter.function.pnames:
            array_names.append(name)
            array_names.append(name + '_err')
        dtype = np.dtype(list(zip(array_names, ['float'] * len(array_names))))
        self.results = StructArrayData(
            data=np.zeros(len(self.filenames), dtype=dtype))
        #self.results_err = StructArrayData(data = np.zeros(len(self.filenames), dtype = dtype))
        self.results.data_updated = True
        #self.results_err.data_updated = True
        #self.log = '===========\nFit results\n===========\n\n'
        return True

    def save_results(self, fname):
        """Saves results to disk

        :param str fname: output filename

        When :attr:`log_name` is set, a reStructuredText log with one image
        directive per input file is written there as well.
        """
        np.save(fname, self.results.data)
        if self.log_name:
            self.log = '===========\nFit results\n===========\n\n'
            for data_fname in self.filenames.filenames:
                imagename = data_fname + '.png'
                # NOTE(review): process_selected saves plots under a 'fits'
                # subdirectory, but only the basename is referenced here --
                # confirm the log is meant to live next to the images.
                self.log += '.. image:: %s\n' % os.path.basename(imagename)
            with open(self.log_name, 'w') as f:
                f.write(self.log)
class RangeEditorDemo(HasTraits):
    """ Shows RangeEditor for small, medium and large integer ranges and for
    a float range, each in every editor style on its own tab.
    """

    def _four_styles(trait_name, tab_label):
        # Simple / custom / text / readonly renderings separated by rules.
        return Group(Item(trait_name, style='simple', label='Simple'),
                     Item('_'),
                     Item(trait_name, style='custom', label='Custom'),
                     Item('_'),
                     Item(trait_name, style='text', label='Text'),
                     Item('_'),
                     Item(trait_name, style='readonly', label='ReadOnly'),
                     label=tab_label)

    # The four Range variants:
    small_int_range = Range(1, 16)
    medium_int_range = Range(1, 25)
    large_int_range = Range(1, 150)
    float_range = Range(0.0, 150.0)

    # narrow (< 17 wide), medium (17 to 100) and wide (> 100) integer ranges,
    # then the float range
    int_range_group1 = _four_styles('small_int_range', "Small Int")
    int_range_group2 = _four_styles('medium_int_range', "Medium Int")
    int_range_group3 = _four_styles('large_int_range', "Large Int")
    float_range_group = _four_styles('float_range', "Float")

    # One tabbed panel per data type.
    view1 = View(int_range_group1,
                 int_range_group2,
                 int_range_group3,
                 float_range_group,
                 title='RangeEditor',
                 buttons=['OK'])

    # the factory was only needed while building the class body
    del _four_styles
class SetStep(HasTraits):
    """ Lets the user pick the step / time shown from a FileSource, and holds
    sequence parameters (start/stop/step and t0/t1/dt/n_step).
    """

    _viewer = Instance(Viewer)
    _source = Instance(FileSource)

    seq_start = Int(0)
    seq_stop = Int(-1)
    seq_step = Int(1)

    seq_t0 = Float
    seq_t1 = Float
    seq_dt = Float
    seq_n_step = Int

    _step_editor = RangeEditor(low_name='step_low', high_name='step_high',
                               label_width=28, auto_set=True, mode='slider')
    # 'step' / 'time' stay None until __source_changed adds real traits for
    # a source that has steps / times.
    step = None
    step_low = Int
    step_high = Int

    _time_editor = RangeEditor(low_name='time_low', high_name='time_high',
                               label_width=28, auto_set=True, mode='slider')
    time = None
    time_low = Float
    time_high = Float

    file_changed = Bool(False)

    # re-entrancy guard while step and time are set together
    is_adjust = False

    traits_view = View(
        Item('step', defined_when='step is not None', editor=_step_editor),
        Item('time', defined_when='time is not None', editor=_time_editor),
        HGroup(Heading('steps:'),
               Item('seq_start', label='start'),
               Item('seq_stop', label='stop'),
               Item('seq_step', label='step'),
               Heading('times:'),
               Item('seq_t0', label='t0'),
               Item('seq_t1', label='t1'),
               Item('seq_dt', label='dt'),
               Item('seq_n_step', label='n_step')),
    )

    def __source_changed(self, old, new):
        """Add step/time traits and slider limits for the new source."""
        steps = self._source.steps
        if len(steps):
            self.add_trait('step', Int(0))
            self.step_low, self.step_high = steps[0], steps[-1]

        times = self._source.times
        if len(times):
            self.add_trait('time', Float(0.0))
            self.time_low, self.time_high = times[0], times[-1]

    def _set_step_time(self, step, time):
        """Set step and time together without re-entering their handlers.

        The guard is always reset, even if setting a value raises, so later
        changes are not silently ignored.
        """
        self.is_adjust = True
        try:
            self.step = step
            self.time = time
        finally:
            self.is_adjust = False

    def _step_changed(self, old, new):
        if new == old:
            return
        if not self.is_adjust:
            step, time = self._source.get_step_time(step=new)
            self._set_step_time(step, time)
            self._viewer.set_source_filename(self._source.filename)

    def _time_changed(self, old, new):
        if new == old:
            return
        if not self.is_adjust:
            step, time = self._source.get_step_time(time=new)
            self._set_step_time(step, time)
            self._viewer.set_source_filename(self._source.filename)

    def _file_changed_changed(self, old, new):
        """Refresh the slider limits after the source file changed."""
        if new:
            steps = self._source.steps
            if len(steps):
                self.step_low, self.step_high = steps[0], steps[-1]

            times = self._source.times
            if len(times):
                self.time_low, self.time_high = times[0], times[-1]

            self.file_changed = False

    @on_trait_change('step_high, time_high')
    def init_seq_selection(self, name, new):
        """Initialise the sequence fields from the step/time limits; stops
        listening to 'time_high' after it has been seen once."""
        self.seq_t0 = self.time_low
        self.seq_t1 = self.time_high
        self.seq_n_step = self.step_high - self.step_low + 1
        self.seq_dt = (self.seq_t1 - self.seq_t0) / self.seq_n_step
        self.seq_start = self.step_low
        self.seq_stop = self.step_high + 1

        if name == 'time_high':
            self.on_trait_change(self.init_seq_selection, 'time_high',
                                 remove=True)

    def _seq_n_step_changed(self, old, new):
        if new == old:
            return
        if self.seq_n_step == 0:
            # mirror _seq_dt_changed: avoid dividing by zero
            return
        self.seq_dt = (self.seq_t1 - self.seq_t0) / self.seq_n_step

    def _seq_dt_changed(self, old, new):
        if new == old:
            return
        if self.seq_dt == 0.0:
            return
        n_step = int(round((self.seq_t1 - self.seq_t0) / self.seq_dt))
        self.seq_n_step = max(1, n_step)
class BuiltinImage(Source):
    """Source exposing one of several built-in tvtk image algorithms.

    The ``source`` enum selects an entry of ``_source_dict``. The selected
    algorithm becomes ``data_source``, and its output is published as this
    object's output.
    """

    # The version of this class.  Used for persistence.
    __version__ = 0

    # Flag to set the image data type.
    source = Enum('ellipsoid', 'gaussian', 'grid', 'mandelbrot', 'noise',
                  'sinusoid', 'rt_analytic',
                  desc='which image data source to be used')

    # The image algorithm currently in use; must be a tvtk.ImageAlgorithm.
    data_source = Instance(tvtk.ImageAlgorithm, allow_none=False,
                           record=True)

    # Information about what this object can produce.
    output_info = PipelineInfo(datasets=['image_data'],
                               attribute_types=['any'],
                               attributes=['any'])

    # Create the UI for the traits.
    view = View(Group(Item(name='source'),
                      Item(name='data_source',
                           style='custom',
                           resizable=True),
                      label='Image Source',
                      show_labels=False),
                resizable=True)

    ########################################
    # Private traits.

    # A dictionary that maps the source names to instances of the
    # image data objects.
    _source_dict = Dict(Str,
                        Instance(tvtk.ImageAlgorithm,
                                 allow_none=False))

    ######################################################################
    # `object` interface
    ######################################################################
    def __init__(self, **traits):
        # Call parent class' init.
        super(BuiltinImage, self).__init__(**traits)

        # Initialize the source to the default mode's instance from
        # the dictionary if the caller did not choose one.
        if 'source' not in traits:
            self._source_changed(self.source)

    def __set_pure_state__(self, state):
        # Restore `source` first so `data_source` points at the right
        # algorithm before the rest of the state is applied.
        self.source = state.source
        super(BuiltinImage, self).__set_pure_state__(state)

    ######################################################################
    # Non-public methods.
    ######################################################################
    def _source_changed(self, value):
        """Invoked (automatically) when the `source` trait is changed.

        Switches `data_source` to the algorithm registered under the new
        source name in `_source_dict`.
        """
        self.data_source = self._source_dict[value]

    def _data_source_changed(self, old, new):
        """Invoked (automatically) when the image data source is changed.

        Publishes the new algorithm's output. It also moves the `render`
        listener from the previous algorithm (if any) to the new one.
        """
        self.outputs = [new.output]
        if old is not None:
            old.on_trait_change(self.render, remove=True)
        new.on_trait_change(self.render)

    def __source_dict_default(self):
        """The default _source_dict trait."""
        sd = {
            'ellipsoid': tvtk.ImageEllipsoidSource(),
            'gaussian': tvtk.ImageGaussianSource(),
            'grid': tvtk.ImageGridSource(),
            'mandelbrot': tvtk.ImageMandelbrotSource(),
            'noise': tvtk.ImageNoiseSource(),
            'sinusoid': tvtk.ImageSinusoidSource(),
        }
        # RTAnalyticSource may be missing from the installed tvtk; fall back
        # to a noise source so that the 'rt_analytic' option still resolves.
        if hasattr(tvtk, 'RTAnalyticSource'):
            sd['rt_analytic'] = tvtk.RTAnalyticSource()
        else:
            sd['rt_analytic'] = tvtk.ImageNoiseSource()
        return sd
class DataSourceWizardView(DataSourceWizard):
    """TraitsUI front-end for ``DataSourceWizard``.

    Lays out the list of available arrays, the data-type questions, a
    preview window and OK/Cancel buttons. ``view_wizard`` opens the dialog
    and stores its UI handle in ``ui`` so the buttons can dispose of it.
    """

    #----------------------------------------------------------------------
    # Private traits
    #----------------------------------------------------------------------

    _top_label = Str('Describe your data')

    # Shown next to an alert icon while ``_is_not_ok`` is True.
    _info_text = Str('Array sizes do not match')

    _array_label = Str('Available arrays')

    _data_type_text = Str("What does your data represent?")

    _lines_text = Str("Connect the points with lines")

    _scalar_data_text = Str("Array giving the value of the scalars")

    _optional_scalar_data_text = Str("Associate scalars with the data points")

    _connectivity_text = Str("Array giving the triangles")

    _vector_data_text = Str("Associate vector components")

    # Caption above the coordinate selectors, chosen by ``position_type_``.
    _position_text = Property(depends_on="position_type_")

    _position_text_dict = {
        'explicit': 'Coordinates of the data points:',
        'orthogonal grid': 'Position of the layers along each axis:',
    }

    def _get__position_text(self):
        return self._position_text_dict.get(self.position_type_, "")

    _shown_help_text = Str

    # One wrapper (name + shape) per available array, for the array table.
    _data_sources_wrappers = Property(depends_on='data_sources')

    def _get__data_sources_wrappers(self):
        return [
            ArrayColumnWrapper(name=name,
                               shape=repr(self.data_sources[name].shape))
            for name in self._data_sources_names
        ]

    # A trait pointing to the object itself, to play well with TraitsUI:
    # the views below embed sub-views of this object via InstanceEditor.
    _self = Instance(DataSourceWizard)

    # Name of the per-data-type sub-view, e.g. "_point_data_view".
    _suitable_traits_view = Property(depends_on="data_type_")

    def _get__suitable_traits_view(self):
        return "_%s_data_view" % self.data_type_

    # UI handle returned by ``edit_traits`` in ``view_wizard``; False until
    # the wizard has been shown.
    ui = Any(False)

    _preview_button = Button(label='Preview structure')

    def __preview_button_fired(self):
        """Build the data source and show it in the preview window."""
        if self.ui:
            self.build_data_source()
            self.preview()

    _ok_button = Button(label='OK')

    def __ok_button_fired(self):
        """Close the wizard and build the data source."""
        if self.ui:
            self.ui.dispose()
            self.build_data_source()

    _cancel_button = Button(label='Cancel')

    def __cancel_button_fired(self):
        """Close the wizard without building the data source."""
        if self.ui:
            self.ui.dispose()

    _is_ok = Bool

    _is_not_ok = Bool

    def _anytrait_changed(self):
        """ Validates if the OK button is enabled.
        """
        if self.ui:
            self._is_ok = self.check_arrays()
            self._is_not_ok = not self._is_ok

    _preview_window = Instance(PreviewWindow, ())

    _info_image = Instance(ImageResource,
                           ImageLibrary.image_resource('@std:alert16',))

    #----------------------------------------------------------------------
    # TraitsUI views
    #----------------------------------------------------------------------

    # NOTE(review): several of the *_group attributes below end with a
    # trailing comma after the closing parenthesis. That makes them 1-tuples
    # wrapping a Group rather than Group objects, and the views appear to
    # rely on TraitsUI accepting such a sequence as group content. Confirm
    # the layout before removing those commas.

    _coordinates_group = \
        HGroup(
            Item('position_x', label='x',
                 editor=EnumEditor(name='_data_sources_names',
                                   invalid='_is_not_ok')),
            Item('position_y', label='y',
                 editor=EnumEditor(name='_data_sources_names',
                                   invalid='_is_not_ok')),
            Item('position_z', label='z',
                 editor=EnumEditor(name='_data_sources_names',
                                   invalid='_is_not_ok')),
        )

    _position_group = \
        Group(
            Item('position_type'),
            Group(
                Item('_position_text', style='readonly',
                     resizable=False,
                     show_label=False),
                _coordinates_group,
                visible_when='not position_type_=="image data"',
            ),
            Group(
                Item('grid_shape_source_',
                     label='Grid shape',
                     editor=EnumEditor(
                         name='_grid_shape_source_labels',
                         invalid='_is_not_ok')),
                HGroup(
                    spring,
                    Item('grid_shape', style='custom',
                         editor=ArrayEditor(width=-60),
                         show_label=False),
                    enabled_when='grid_shape_source==""',
                ),
                visible_when='position_type_=="image data"',
            ),
            label='Position of the data points',
            show_border=True,
            show_labels=False,
        ),

    _connectivity_group = \
        Group(
            HGroup(
                Item('_connectivity_text', style='readonly',
                     resizable=False),
                spring,
                Item('connectivity_triangles',
                     editor=EnumEditor(name='_data_sources_names'),
                     show_label=False,
                     ),
                show_labels=False,
            ),
            label='Connectivity information',
            show_border=True,
            show_labels=False,
            enabled_when='position_type_=="explicit"',
        ),

    _scalar_data_group = \
        Group(
            Item('_scalar_data_text', style='readonly',
                 resizable=False,
                 show_label=False),
            HGroup(
                spring,
                Item('scalar_data',
                     editor=EnumEditor(name='_data_sources_names',
                                       invalid='_is_not_ok')),
                show_labels=False,
            ),
            label='Scalar value',
            show_border=True,
            show_labels=False,
        )

    _optional_scalar_data_group = \
        Group(
            HGroup(
                'has_scalar_data',
                Item('_optional_scalar_data_text',
                     resizable=False,
                     style='readonly'),
                show_labels=False,
            ),
            Item('_scalar_data_text', style='readonly',
                 resizable=False,
                 enabled_when='has_scalar_data',
                 show_label=False),
            HGroup(
                spring,
                Item('scalar_data',
                     editor=EnumEditor(name='_data_sources_names',
                                       invalid='_is_not_ok'),
                     enabled_when='has_scalar_data'),
                show_labels=False,
            ),
            label='Scalar data',
            show_border=True,
            show_labels=False,
        ),

    _vector_data_group = \
        VGroup(
            HGroup(
                Item('vector_u', label='u',
                     editor=EnumEditor(name='_data_sources_names',
                                       invalid='_is_not_ok')),
                Item('vector_v', label='v',
                     editor=EnumEditor(name='_data_sources_names',
                                       invalid='_is_not_ok')),
                Item('vector_w', label='w',
                     editor=EnumEditor(name='_data_sources_names',
                                       invalid='_is_not_ok')),
            ),
            label='Vector data',
            show_border=True,
        ),

    _optional_vector_data_group = \
        VGroup(
            HGroup(
                Item('has_vector_data', show_label=False),
                Item('_vector_data_text', style='readonly',
                     resizable=False,
                     show_label=False),
            ),
            HGroup(
                Item('vector_u', label='u',
                     editor=EnumEditor(name='_data_sources_names',
                                       invalid='_is_not_ok')),
                Item('vector_v', label='v',
                     editor=EnumEditor(name='_data_sources_names',
                                       invalid='_is_not_ok')),
                Item('vector_w', label='w',
                     editor=EnumEditor(name='_data_sources_names',
                                       invalid='_is_not_ok')),
                enabled_when='has_vector_data',
            ),
            label='Vector data',
            show_border=True,
        ),

    _array_view = \
        View(
            Item('_array_label', editor=TitleEditor(),
                 show_label=False),
            Group(
                Item('_data_sources_wrappers',
                     editor=TabularEditor(
                         adapter=ArrayColumnAdapter(),
                     ),
                     ),
                show_border=True,
                show_labels=False
            ))

    _questions_view = View(
        Item('_top_label', editor=TitleEditor(),
             show_label=False),
        HGroup(
            Item('_data_type_text', style='readonly',
                 resizable=False),
            spring,
            'data_type',
            spring,
            show_border=True,
            show_labels=False,
        ),
        HGroup(
            Item('_self', style='custom',
                 editor=InstanceEditor(
                     view_name='_suitable_traits_view'),
                 ),
            Group(
                # FIXME: Giving up on context sensitive help
                # because of lack of time.
                Item('_preview_button',
                     enabled_when='_is_ok'),
                Item('_preview_window', style='custom',
                     label='Preview structure'),
                show_labels=False,
            ),
            show_labels=False,
            show_border=True,
        ),
    )

    _point_data_view = \
        View(Group(
            Group(_coordinates_group,
                  label='Position of the data points',
                  show_border=True,
                  ),
            HGroup(
                'lines',
                Item('_lines_text', style='readonly',
                     resizable=False),
                label='Lines',
                show_labels=False,
                show_border=True,
            ),
            _optional_scalar_data_group,
            _optional_vector_data_group,
            # XXX: hack to have more vertical space
            Label('\n'),
            Label('\n'),
            Label('\n'),
        ))

    _surface_data_view = \
        View(Group(
            _position_group,
            _connectivity_group,
            _optional_scalar_data_group,
            _optional_vector_data_group,
        ))

    _vector_data_view = \
        View(Group(
            _vector_data_group,
            _position_group,
            _optional_scalar_data_group,
        ))

    _volumetric_data_view = \
        View(Group(
            _scalar_data_group,
            _position_group,
            _optional_vector_data_group,
        ))

    _wizard_view = View(
        Group(
            HGroup(
                Item('_self', style='custom', show_label=False,
                     editor=InstanceEditor(view='_array_view'),
                     width=0.17,
                     ),
                '_',
                Item('_self', style='custom', show_label=False,
                     editor=InstanceEditor(view='_questions_view'),
                     ),
            ),
            HGroup(
                Item('_info_image', editor=ImageEditor(),
                     visible_when="_is_not_ok"),
                Item('_info_text', style='readonly', resizable=False,
                     visible_when="_is_not_ok"),
                spring,
                '_cancel_button',
                Item('_ok_button', enabled_when='_is_ok'),
                show_labels=False,
            ),
        ),
        title='Import arrays',
        resizable=True,
    )

    #----------------------------------------------------------------------
    # Public interface
    #----------------------------------------------------------------------

    def __init__(self, **traits):
        # NOTE(review): this calls DataSourceFactory.__init__ directly
        # rather than going through super(), so any __init__ defined on
        # DataSourceWizard is bypassed. Confirm this is intended before
        # switching to super().
        DataSourceFactory.__init__(self, **traits)
        self._self = self

    def view_wizard(self):
        """ Pops up the view of the wizard, and keeps the reference to
            it to be able to close it.
        """
        # FIXME: Workaround for traits bug in enabled_when: touch these
        # traits before the UI is built.
        self.position_type_
        self.data_type_
        self._suitable_traits_view
        self.grid_shape_source
        self._is_ok
        self.ui = self.edit_traits(view='_wizard_view')

    def preview(self):
        """ Display a preview of the data structure in the preview
            window.
        """
        self._preview_window.clear()
        self._preview_window.add_source(self.data_source)
        g = Glyph()
        # Use the first entry of the glyph list as the glyph shape.
        g.glyph.glyph_source.glyph_source = \
            g.glyph.glyph_source.glyph_list[0]
        g.glyph.scale_mode = 'data_scaling_off'
        if not (self.has_vector_data or self.data_type_ == 'vector'):
            # No vectors to show: use 'cross' glyphs rendered as points.
            g.glyph.glyph_source.glyph_source.glyph_type = 'cross'
            g.actor.property.representation = 'points'
            g.actor.property.point_size = 3.
        self._preview_window.add_module(g)
        if self.data_type_ not in ('point', 'vector') or self.lines:
            s = Surface()
            s.actor.property.opacity = 0.3
            self._preview_window.add_module(s)
        if self.data_type_ != 'point':
            # Add a faint surface over the extracted edges.
            self._preview_window.add_filter(ExtractEdges())
            s = Surface()
            s.actor.property.opacity = 0.2
            self._preview_window.add_module(s)
class GlyphSource(Component):
    """Component that supplies the glyph geometry (``glyph_source``).

    The active glyph is chosen from ``glyph_dict``. A ``GlyphSource2D`` is
    output directly. Any other source is routed through an internal
    ``TransformFilter`` (``_trfm``), so ``glyph_position`` can shift or
    rotate it.
    """

    # The version of this class.  Used for persistence.
    __version__ = 1

    # Glyph position.  This can be one of ['head', 'tail', 'center'],
    # and indicates the position of the glyph with respect to the
    # input point data.  Please note that this will work correctly
    # only if you do not mess with the source glyph's basic size.  For
    # example if you use a ConeSource and set its height != 1, then the
    # 'head' and 'tail' options will not work correctly.
    glyph_position = Trait('center', TraitPrefixList(['head', 'tail',
                                                      'center']),
                           desc='position of glyph w.r.t. data point')

    # The Source to use for the glyph.  This is chosen from
    # `self.glyph_list` or `self.glyph_dict`.
    glyph_source = Instance(tvtk.Object, allow_none=False, record=True)

    # A dict of glyphs to use.
    glyph_dict = Dict(desc='the glyph sources to select from',
                      record=False)

    # A list of predefined glyph sources that can be used.
    glyph_list = Property(List(tvtk.Object), record=False)

    ########################################
    # Private traits.

    # The transformation to use to place glyph appropriately.
    _trfm = Instance(tvtk.TransformFilter, args=())

    # Re-entrancy guard: True while the handlers below are updating the
    # pipeline. Change handlers and `render` are suppressed while it is set.
    _updating = Bool(False)

    ########################################
    # View related traits.

    view = View(Group(Group(Item(name='glyph_position')),
                      Group(Item(name='glyph_source',
                                 style='custom',
                                 resizable=True,
                                 editor=InstanceEditor(name='glyph_list'),
                                 ),
                            label='Glyph Source',
                            show_labels=False)
                      ),
                resizable=True)

    ######################################################################
    # `Base` interface
    ######################################################################
    def __get_pure_state__(self):
        d = super(GlyphSource, self).__get_pure_state__()
        # Transient / derived traits are not persisted.
        for attr in ('_updating', 'glyph_list'):
            d.pop(attr, None)
        return d

    def __set_pure_state__(self, state):
        if 'glyph_dict' in state:
            # Set their state.
            set_state(self, state, first=['glyph_dict'], ignore=['*'])
            ignore = ['glyph_dict']
        else:
            # Older state: rebuild the dict state from the persisted list.
            gd = self.glyph_dict
            gl = self.glyph_list
            handle_children_state(gl, state.glyph_list)
            for g, gs in zip(gl, state.glyph_list):
                name = camel2enthought(g.__class__.__name__)
                if name not in gd:
                    gd[name] = g
                # Set the glyph source's state.
                set_state(g, gs)
            ignore = ['glyph_list']
        g_name = state.glyph_source.__metadata__['class_name']
        name = camel2enthought(g_name)
        # Set the correct glyph_source.
        self.glyph_source = self.glyph_dict[name]
        set_state(self, state, ignore=ignore)

    ######################################################################
    # `Component` interface
    ######################################################################
    def setup_pipeline(self):
        """Override this method so that it *creates* the tvtk
        pipeline.

        This method is invoked when the object is initialized via
        `__init__`.  Note that at the time this method is called, the
        tvtk data pipeline will *not* yet be setup.  So upstream data
        will not be available.  The idea is that you simply create the
        basic objects and setup those parts of the pipeline not
        dependent on upstream sources and filters.  You should also
        set the `actors` attribute up at this point.
        """

        self._trfm.transform = tvtk.Transform()
        # Setup the glyphs.
        self.glyph_source = self.glyph_dict['glyph_source2d']

    def update_pipeline(self):
        """Override this method so that it *updates* the tvtk pipeline
        when data upstream is known to have changed.

        This method is invoked (automatically) when any of the inputs
        sends a `pipeline_changed` event.
        """
        self._glyph_position_changed(self.glyph_position)
        self.pipeline_changed = True

    def update_data(self):
        """Override this method so that it flushes the vtk pipeline if
        that is necessary.

        This method is invoked (automatically) when any of the inputs
        sends a `data_changed` event.
        """
        self.data_changed = True

    def render(self):
        # Skip intermediate renders while the handlers are updating.
        if not self._updating:
            super(GlyphSource, self).render()

    ######################################################################
    # Non-public methods.
    ######################################################################
    def _glyph_source_changed(self, value):
        """Register ``value`` in ``glyph_dict``, record it, and rewire outputs."""
        if self._updating:
            return

        gd = self.glyph_dict
        # Reuse the key under which `value` is already registered, so the
        # recorded script refers to an existing dict entry. Otherwise
        # register it under its enthought-style class name.
        key = next((k for k, v in gd.items() if v is value), None)
        if key is None:
            key = camel2enthought(value.__class__.__name__)
            gd[key] = value

        # Now change the glyph's source trait.
        self._updating = True
        try:
            recorder = self.recorder
            if recorder is not None:
                name = recorder.get_script_id(self)
                lhs = '%s.glyph_source' % name
                rhs = '%s.glyph_dict[%r]' % (name, key)
                recorder.record('%s = %s' % (lhs, rhs))

            # GlyphSource2D is output directly; other sources go through
            # the transform filter so that they can be positioned.
            if value.__class__.__name__ == 'GlyphSource2D':
                self.outputs = [value.output]
            else:
                self._trfm.input = value.output
                self.outputs = [self._trfm.output]
            value.on_trait_change(self.render)
        finally:
            # Never leave the guard set. Otherwise every later change would
            # be silently ignored.
            self._updating = False

        # Now update the glyph position since the transformation might
        # be different.
        self._glyph_position_changed(self.glyph_position)

    def _glyph_position_changed(self, value):
        """Offset/rotate the current glyph for 'head', 'tail' or 'center'."""
        if self._updating:
            return

        self._updating = True
        try:
            tr = self._trfm.transform
            tr.identity()

            g = self.glyph_source
            name = g.__class__.__name__
            # Compute transformation factor
            if name == 'CubeSource':
                tr_factor = g.x_length / 2.0
            elif name == 'CylinderSource':
                tr_factor = -g.height / 2.0
            elif name == 'ConeSource':
                tr_factor = g.height / 2.0
            elif name == 'SphereSource':
                tr_factor = g.radius
            else:
                tr_factor = 1.

            # Translate the glyph. Sources without a `center` attribute
            # (e.g. Axes) are left untouched in every mode.
            if value == 'tail':
                if name == 'GlyphSource2D':
                    g.center = 0.5, 0.0, 0.0
                elif name == 'ArrowSource':
                    pass
                elif name == 'CylinderSource':
                    g.center = 0, tr_factor, 0.0
                elif hasattr(g, 'center'):
                    g.center = tr_factor, 0.0, 0.0
            elif value == 'head':
                if name == 'GlyphSource2D':
                    g.center = -0.5, 0.0, 0.0
                elif name == 'ArrowSource':
                    tr.translate(-1, 0, 0)
                elif name == 'CylinderSource':
                    g.center = 0, -tr_factor, 0.0
                elif hasattr(g, 'center'):
                    g.center = -tr_factor, 0.0, 0.0
            else:
                if name == 'ArrowSource':
                    tr.translate(-0.5, 0, 0)
                elif hasattr(g, 'center'):
                    g.center = 0.0, 0.0, 0.0

            if name == 'CylinderSource':
                tr.rotate_z(90)
        finally:
            self._updating = False

        self.render()

    def _get_glyph_list(self):
        """Return the glyphs: the predefined ones first, then any extras.

        The predefined ones keep the original implementation's order.
        Predefined keys missing from `glyph_dict` (e.g. after restoring
        a partial state) are skipped instead of raising KeyError.
        """
        order = [
            'glyph_source2d', 'arrow_source', 'cone_source',
            'cylinder_source', 'sphere_source', 'cube_source', 'axes',
        ]
        gd = self.glyph_dict
        keys = [key for key in order if key in gd]
        keys.extend(key for key in gd if key not in order)
        return [gd[key] for key in keys]

    def _glyph_dict_default(self):
        g = {
            'glyph_source2d': tvtk.GlyphSource2D(glyph_type='arrow',
                                                 filled=False),
            'arrow_source': tvtk.ArrowSource(),
            'cone_source': tvtk.ConeSource(height=1.0, radius=0.2,
                                           resolution=15),
            'cylinder_source': tvtk.CylinderSource(height=1.0,
                                                   radius=0.15,
                                                   resolution=10),
            'sphere_source': tvtk.SphereSource(),
            'cube_source': tvtk.CubeSource(),
            'axes': tvtk.Axes(symmetric=1),
        }
        return g