class Demo(HasTraits):
    """Side-by-side showcase of CSVListEditor on assorted List traits.

    The left column edits each list through a CSVListEditor (several with
    non-default options); the right column shows each list's Python ``str``
    form in a disabled TextEditor.
    """

    # Lists with simple element types.
    list1 = List(Int)
    list2 = List(Float)
    list3 = List(Str, maxlen=3)
    list4 = List(Enum('red', 'green', 'blue', 2, 3))
    list5 = List(Range(low=0.0, high=10.0))

    # 'low' and 'high' are used to demonstrate lists containing dynamic ranges.
    low = Float(0.0)
    high = Float(1.0)
    list6 = List(Range(low=-1.0, high='high'))
    list7 = List(Range(low='low', high='high'))

    # Buttons that act on list1.
    pop1 = Button("Pop from first list")
    sort1 = Button("Sort first list")

    # str(self.list1); declared as depending on list1.
    list1str = Property(Str, depends_on='list1')

    traits_view = View(
        HGroup(
            # Left column: the CSVListEditor examples.
            VGroup(
                Item('list1', label="List(Int)",
                     editor=CSVListEditor(ignore_trailing_sep=False),
                     tooltip='options: ignore_trailing_sep=False'),
                Item('list1', label="List(Int)", style='readonly',
                     editor=CSVListEditor()),
                Item('list2', label="List(Float)",
                     editor=CSVListEditor(enter_set=True, auto_set=False),
                     tooltip='options: enter_set=True, auto_set=False'),
                Item('list3', label="List(Str, maxlen=3)",
                     editor=CSVListEditor()),
                Item('list4', label="List(Enum('red', 'green', 'blue', 2, 3))",
                     editor=CSVListEditor(sep=None),
                     tooltip='options: sep=None'),
                Item('list5', label="List(Range(low=0.0, high=10.0))",
                     editor=CSVListEditor()),
                Item('list6', label="List(Range(low=-1.0, high='high'))",
                     editor=CSVListEditor()),
                Item('list7', label="List(Range(low='low', high='high'))",
                     editor=CSVListEditor()),
                springy=True,
            ),
            # Right column: one disabled str() display per row on the left.
            # list1 occupies two rows on the left, hence list1str twice here.
            VGroup(*(
                UItem(row_name, editor=TextEditor(), enabled_when='False',
                      width=240)
                for row_name in ('list1str', 'list1str', 'list2', 'list3',
                                 'list4', 'list5', 'list6', 'list7')
            )),
        ),
        '_',
        HGroup('low', 'high', spring, UItem('pop1'), UItem('sort1')),
        Heading("Notes"),
        *(Label(note) for note in (
            "Hover over a list to see which editor options are set, "
            "if any.",
            "The editor of the first list, List(Int), uses "
            "ignore_trailing_sep=False, so a trailing comma is "
            "an error.",
            "The second list is a read-only view of the first list.",
            "The editor of the List(Float) example has enter_set=True "
            "and auto_set=False; press Enter to validate.",
            "The List(Str) example will accept at most 3 elements.",
            "The editor of the List(Enum(...)) example uses sep=None, "
            "i.e. whitespace acts as a separator.",
            "The last two List(Range(...)) examples take one or both "
            "of their limits from the Low and High fields below.",
        )),
        width=720,
        title="CSVListEditor Demonstration"
    )

    def _list1_default(self):
        # Initial contents of list1.
        return [1, 4, 0, 10]

    def _get_list1str(self):
        # Getter backing the list1str Property.
        return str(self.list1)

    def _pop1_fired(self):
        # Remove and print the last element of list1; do nothing if empty.
        if not self.list1:
            return
        print(self.list1.pop())

    def _sort1_fired(self):
        # Sort list1 in place.
        self.list1.sort()
class QSSQLObject(__QS_Object__):
    # Object backed by a relational database. NOTE: the docstring below is
    # read at runtime (cursor() formats self.__doc__ into its "not connected"
    # error message), so it is deliberately kept verbatim.
    """基于关系数据库的对象"""

    # ---- Connection parameters (traits; label/arg_type/order are UI metadata) ----
    DBType = Enum("MySQL", "SQL Server", "Oracle", arg_type="SingleOption", label="数据库类型", order=0)
    DBName = Str("Scorpion", arg_type="String", label="数据库名", order=1)
    IPAddr = Str("127.0.0.1", arg_type="String", label="IP地址", order=2)
    Port = Range(low=0, high=65535, value=3306, arg_type="Integer", label="端口", order=3)
    User = Str("root", arg_type="String", label="用户名", order=4)
    # NOTE(review): a literal default password ships with the class; consider
    # removing it so credentials must always be supplied explicitly.
    Pwd = Password("shuntai11", arg_type="String", label="密码", order=5)
    # Prepended to table names by addIndex().
    TablePrefix = Str("", arg_type="String", label="表名前缀", order=6)
    CharSet = Enum("utf8", "gbk", "gb2312", "gb18030", "cp936", "big5", arg_type="SingleOption", label="字符集", order=7)
    # "default" picks a driver from DBType and falls back to pyodbc on failure.
    Connector = Enum("default", "cx_Oracle", "pymssql", "mysql.connector", "pyodbc", arg_type="SingleOption", label="连接器", order=8)
    # Used only on the pyodbc path: if non-empty, connect via this ODBC DSN.
    DSN = Str("", arg_type="String", label="数据源", order=9)

    def __init__(self, sys_args=None, config_file=None, **kwargs):
        # No live connection until connect() is called.
        self._Connection = None
        # Fresh dict per call: a shared mutable default could leak state
        # between instances if the base class mutates it.
        if sys_args is None:
            sys_args = {}
        super().__init__(sys_args=sys_args, config_file=config_file, **kwargs)

    def __getstate__(self):
        # The live connection object is not pickled; only a flag recording
        # whether one was open is stored in its place.
        state = self.__dict__.copy()
        state["_Connection"] = self.isAvailable()
        return state

    def __setstate__(self, state):
        super().__setstate__(state)
        # _Connection now holds the flag written by __getstate__. Clear it
        # before reconnecting so a failed reconnect never leaves a bool
        # masquerading as a connection.
        WasConnected = bool(self._Connection)
        self._Connection = None
        if WasConnected:
            self.connect()

    def connect(self):
        """Open a database connection according to Connector / DBType.

        An explicit Connector ("cx_Oracle", "pymssql", "mysql.connector") is
        used as-is and any failure propagates. With Connector == "default"
        the driver is chosen from DBType; if that attempt fails, pyodbc is
        tried instead and Connector is switched to "pyodbc".

        Returns
        -------
        int
            Always 0.

        Raises
        ------
        __QS_Error__
            If Connector names a driver this class does not support.
        """
        Connection = None
        IsDefault = (self.Connector == "default")
        if (self.Connector == "cx_Oracle") or (IsDefault and (self.DBType == "Oracle")):
            try:
                import cx_Oracle
                Connection = cx_Oracle.connect(self.User, self.Pwd, cx_Oracle.makedsn(self.IPAddr, str(self.Port), self.DBName))
            except Exception:
                # Explicit connector: fatal. Default: fall through to pyodbc.
                if not IsDefault:
                    raise
        elif (self.Connector == "pymssql") or (IsDefault and (self.DBType == "SQL Server")):
            try:
                import pymssql
                Connection = pymssql.connect(server=self.IPAddr, port=str(self.Port), user=self.User, password=self.Pwd, database=self.DBName, charset=self.CharSet)
            except Exception:
                if not IsDefault:
                    raise
        elif (self.Connector == "mysql.connector") or (IsDefault and (self.DBType == "MySQL")):
            try:
                import mysql.connector
                Connection = mysql.connector.connect(host=self.IPAddr, port=str(self.Port), user=self.User, password=self.Pwd, database=self.DBName, charset=self.CharSet, autocommit=True)
            except Exception:
                if not IsDefault:
                    raise
        if Connection is None:
            # Reached for Connector == "pyodbc", for a failed "default"
            # attempt above, or for an unsupported connector name.
            if self.Connector not in ("default", "pyodbc"):
                self._Connection = None
                raise __QS_Error__("不支持该连接器(connector) : " + self.Connector)
            import pyodbc
            if self.DSN:
                Connection = pyodbc.connect("DSN=%s;PWD=%s" % (self.DSN, self.Pwd))
            else:
                Connection = pyodbc.connect("DRIVER={%s};DATABASE=%s;SERVER=%s;UID=%s;PWD=%s" % (self.DBType, self.DBName, self.IPAddr, self.User, self.Pwd))
            self.Connector = "pyodbc"
        self._Connection = Connection
        return 0

    def disconnect(self):
        """Close the connection, if any. Returns 0."""
        if self._Connection is not None:
            try:
                self._Connection.close()
            finally:
                # Forget the handle even if close() raised.
                self._Connection = None
        return 0

    def isAvailable(self):
        """Return True if a connection object is currently held."""
        return (self._Connection is not None)

    def cursor(self, sql_str=None):
        """Return a new cursor; if *sql_str* is given, execute it first.

        Raises __QS_Error__ when not connected. The cursor is closed again
        if executing *sql_str* fails.
        """
        if self._Connection is None:
            raise __QS_Error__("%s尚未连接!" % self.__doc__)
        Cursor = self._Connection.cursor()
        if sql_str is None:
            return Cursor
        try:
            Cursor.execute(sql_str)
        except Exception:
            Cursor.close()
            raise
        return Cursor

    def fetchall(self, sql_str):
        """Execute *sql_str* and return all result rows; the cursor is always closed."""
        Cursor = self.cursor()
        try:
            Cursor.execute(sql_str)
            return Cursor.fetchall()
        finally:
            Cursor.close()

    def execute(self, sql_str):
        """Execute *sql_str* and commit. Returns 0; the cursor is always closed.

        Raises __QS_Error__ (via cursor()) when not connected.
        """
        Cursor = self.cursor()
        try:
            Cursor.execute(sql_str)
            self._Connection.commit()
        finally:
            Cursor.close()
        return 0

    def addIndex(self, index_name, table_name, fields, index_type="BTREE"):
        """Create an index on TablePrefix+table_name over *fields*.

        NOTE: identifiers are concatenated into the SQL text (they cannot be
        bound as parameters), so callers must pass trusted names only.
        """
        SQLStr = "CREATE INDEX "+index_name+" USING "+index_type+" ON "+self.TablePrefix+table_name+"("+", ".join(fields)+")"
        return self.execute(SQLStr)
class Pattern(HasTraits):
    """Base class for a 2-D pattern drawn on an embedded Graph.

    Subclasses supply the geometry through pattern_generator_factory() and
    the editable parameters through get_parameter_group(). They may override
    _get_path_length() / _get_delay() to feed the transit-time estimate.
    graph_view() also reads ``self.handler_klass``, which this class does
    not define (presumably provided by subclasses -- confirm).
    """
    # transient=True: excluded from HasTraits pickling.
    graph = Instance(Graph, (), transient=True)
    cx = Float(transient=True)
    cy = Float(transient=True)
    target_radius = Range(0.0, 3.0, 1)
    show_overlap = Bool(False)
    beam_radius = Range(0.0, 3.0, 1)
    path = Str
    name = Property(depends_on='path')

    # Fixed x/y plot limits used by _graph_factory().
    xbounds = (-3, 3)
    ybounds = (-3, 3)

    velocity = Float(1)
    calculated_transit_time = Float
    niterations = Range(1, 200)

    @property
    def kind(self):
        # Concrete subclass name, e.g. used to identify the pattern type.
        return self.__class__.__name__

    def _get_name(self):
        # Display name: the file stem of ``path``, or a placeholder.
        if not self.path:
            return 'New Pattern'
        return os.path.basename(self.path).split('.')[0]

    def _anytrait_changed(self, name, new):
        # Any trait change except the computed result itself (which would
        # recurse) redraws the pattern and refreshes the time estimate.
        if name != 'calculated_transit_time':
            self.replot()
            self.calculate_transit_time()

    def calculate_transit_time(self):
        """Estimate the time to run the pattern ``niterations`` times.

        Per pass, solves  path_length = v*t + 0.5*accel*t**2  (accel fixed
        at 1) for the larger root t, adds _get_delay(), and multiplies by
        niterations. The result is stored in calculated_transit_time.
        """
        n = self.niterations
        acceleration = 1
        # Quadratic a*t**2 + b*t + c = 0 with c = -path_length.
        a = 0.5 * acceleration
        b = self.velocity
        c = -self._get_path_length()
        # For a non-negative path length the discriminant is >= b**2 >= 0;
        # clamp anyway so a bad override cannot produce a complex/invalid root.
        root = max(b ** 2 - 4 * a * c, 0.0) ** 0.5
        t1 = (-b + root) / (2.0 * a)
        t2 = (-b - root) / (2.0 * a)
        self.calculated_transit_time = (max(t1, t2) + self._get_delay()) * n

    def _get_path_length(self):
        # Length of one pass; subclasses override.
        return 0

    def _get_delay(self):
        # Extra per-pass delay; subclasses override.
        return 0

    def _beam_radius_changed(self):
        # overlays[1] is the OverlapOverlay added in _graph_factory().
        oo = self.graph.plots[0].plots['plot0'][0].overlays[1]
        oo.beam_radius = self.beam_radius
        self.replot()

    def _show_overlap_changed(self):
        oo = self.graph.plots[0].plots['plot0'][0].overlays[1]
        oo.visible = self.show_overlap
        oo.request_redraw()

    def _target_radius_changed(self):
        # overlays[0] is the TargetOverlay added in _graph_factory().
        self.graph.plots[0].plots['plot0'][0].overlays[
            0].target_radius = self.target_radius

    def pattern_generator_factory(self, **kw):
        """Return an iterable of (x, y) points; must be overridden."""
        raise NotImplementedError

    def replot(self):
        self.plot()

    def plot(self):
        """Redraw the pattern path and return (x, y) of its last point."""
        data_out = array(list(self.pattern_generator_factory()))
        xs, ys = transpose(data_out)
        self.graph.set_data(xs)
        self.graph.set_data(ys, axis=1)
        return data_out[-1][0], data_out[-1][1]

    def points_factory(self):
        """Return the pattern's points as a list."""
        return list(self.pattern_generator_factory())

    def graph_view(self):
        v = View(Item('graph', style='custom', show_label=False),
                 handler=self.handler_klass,
                 title=self.name)
        return v

    def clear_graph(self):
        # Empty series 1 and 2 (x then y for each); series 0 is left intact.
        graph = self.graph
        for series in (1, 2):
            for axis in (0, 1):
                graph.set_data([], series=series, axis=axis)

    def reset_graph(self, **kw):
        self.graph = self._graph_factory(**kw)

    def _graph_factory(self, with_image=False):
        """Build the Graph: pattern line, target/overlap overlays, 2 extra series.

        ``with_image`` is accepted for compatibility but currently unused.
        """
        g = Graph(window_height=250, window_width=300,
                  container_dict=dict(padding=0))
        g.new_plot(bounds=[250, 250], resizable='', padding=[30, 0, 0, 30])

        g.set_x_limits(*self.xbounds)
        g.set_y_limits(*self.ybounds)

        # Series 0: the pattern path, carrying the two overlays the
        # *_changed handlers address by index (0: target, 1: overlap).
        lp, _plot = g.new_series()
        t = TargetOverlay(component=lp, cx=self.cx, cy=self.cy,
                          target_radius=self.target_radius)
        lp.overlays.append(t)
        overlap_overlay = OverlapOverlay(component=lp,
                                         visible=self.show_overlap)
        lp.overlays.append(overlap_overlay)

        # Series 1 (scatter) and 2 (red line), cleared by clear_graph().
        g.new_series(type='scatter', marker='circle')
        g.new_series(type='line', color='red')
        return g

    def _graph_default(self):
        return self._graph_factory()

    def maker_group(self):
        """Group of subclass parameters plus the common pattern settings."""
        return Group(self.get_parameter_group(),
                     Item('niterations'),
                     HGroup(Item('velocity'),
                            Item('calculated_transit_time',
                                 label='Time (s)',
                                 style='readonly',
                                 format_str='%0.1f')),
                     Item('target_radius'),
                     Item('show_overlap'),
                     Item('beam_radius', enabled_when='show_overlap'),
                     show_border=True,
                     label='Pattern')

    def maker_view(self):
        v = View(HGroup(self.maker_group(),
                        Item('graph', show_label=False, style='custom')),
                 resizable=True)
        return v

    def traits_view(self):
        v = View(self.maker_group(),
                 buttons=['OK', 'Cancel'],
                 title=self.name,
                 resizable=True)
        return v

    def get_parameter_group(self):
        """Return the subclass-specific parameter Group; must be overridden."""
        raise NotImplementedError
class Visualization(HasTraits):
    """Compare two cameras' vignetting calibrations.

    Two Mayavi scenes each show one camera's measured points (selected
    colour channel, normalised to 0..255) together with a semi-transparent
    surface computed from its VignettingCalibration. A matplotlib Figure
    plots the data stored in each camera's spec.pkl; it is updated but not
    currently part of the view.
    """
    meridional = Range(1, 30, 6)
    transverse = Range(0, 30, 11)
    scene_3D1 = Instance(MlabSceneModel, ())
    scene_3D2 = Instance(MlabSceneModel, ())
    figure = Instance(Figure, ())
    colors_list = Enum(*COLORS)
    polynomial_degree = Int(4)
    residual_threshold = Int(20)

    def __init__(self, **traits):
        # Do not forget to call the parent's __init__. Keyword traits are
        # forwarded so initial settings can be overridden by the caller.
        HasTraits.__init__(self, **traits)
        self.update_3dplot()
        self.update_plot()

    @on_trait_change('colors_list, polynomial_degree, residual_threshold')
    def update_3dplot(self):
        """Redraw both 3-D scenes for the current channel and fit settings."""
        color = self.colors_list
        channel = COLOR_INDICES[color]
        # NOTE: pickle.load can execute code from the file; only load
        # trusted measurement files.
        with open(os.path.join(base_path1, 'measurements.pkl'), 'rb') as f:
            measurements1 = cPickle.load(f)
        with open(os.path.join(base_path2, 'measurements.pkl'), 'rb') as f:
            measurements2 = cPickle.load(f)

        vc1 = VignettingCalibration.readMeasurements(
            base_path1,
            polynomial_degree=self.polynomial_degree,
            residual_threshold=self.residual_threshold)
        vc2 = VignettingCalibration.readMeasurements(
            base_path2,
            polynomial_degree=self.polynomial_degree,
            residual_threshold=self.residual_threshold)

        # Skip records whose first field is None, then split into x, y, z.
        x1, y1, z1 = zip(*[a for a in measurements1 if a[0] is not None])
        x2, y2, z2 = zip(*[a for a in measurements2 if a[0] is not None])

        for scene, (x, y, z), vc in zip([self.scene_3D1, self.scene_3D2],
                                        ((x1, y1, z1), (x2, y2, z2)),
                                        (vc1, vc2)):
            fig = scene.mayavi_scene
            mlab.clf(figure=fig)
            zs = np.array(z)[..., channel]
            zs = 255 * zs / zs.max()
            mlab.points3d(x, y, zs, mode='sphere', scale_mode='scalar',
                          scale_factor=20, color=(1, 1, 1), figure=fig)
            surface = np.ones((1200, 1600))
            corrected = raw2RGB(255 / vc.applyVignetting(surface))
            # Pass figure explicitly (like the sibling calls) so the surface
            # is drawn in this scene rather than mlab's current figure.
            mlab.surf(corrected[channel].T,
                      extent=(0, 1600, 0, 1200, 0, 255), opacity=0.5,
                      figure=fig)
            mlab.outline(color=(0, 0, 0), extent=(0, 1600, 0, 1200, 0, 255),
                         figure=fig)

    @on_trait_change('colors_list')
    def update_plot(self):
        """Re-plot both cameras' spec.pkl data on the matplotlib figure."""
        with open(os.path.join(base_path1, 'spec.pkl'), 'rb') as f:
            spec1 = cPickle.load(f)
        with open(os.path.join(base_path2, 'spec.pkl'), 'rb') as f:
            spec2 = cPickle.load(f)

        self.figure.clear()
        axes = self.figure.add_subplot(111)
        axes.plot(spec1[0], spec1[1], label='camera1')
        axes.plot(spec2[0], spec2[1], label='camera2')
        axes.legend()
        canvas = self.figure.canvas
        # The canvas exists only once the figure has been shown in an editor.
        if canvas is not None:
            canvas.draw()

    # the layout of the dialog created
    view = View(
        VGroup(
            VGroup(
                # The spectrogram figure is not shown; to display it, add
                # HGroup(Item('figure', editor=MPLFigureEditor(),
                #             show_label=False),
                #        label='Spectograms', show_border=True)
                HGroup(
                    HGroup(Item('scene_3D1',
                                editor=SceneEditor(scene_class=MayaviScene),
                                height=200, width=200, show_label=False),
                           label='3D', show_border=True),
                    HGroup(Item('scene_3D2',
                                editor=SceneEditor(scene_class=MayaviScene),
                                height=200, width=200, show_label=False),
                           label='3D', show_border=True),
                ),
                '_',
                HGroup(Item('colors_list', style='simple'),
                       Item('polynomial_degree', style='simple'),
                       Item('residual_threshold', style='simple'),
                       Spring()),
            )))
class BaseXYPlot(AbstractPlotRenderer): """ Base class for simple X-vs-Y plots that consist of a single index data array and a single value data array. Subclasses handle the actual rendering, but this base class takes care of most of making sure events are wired up between mappers and data or screen space changes, etc. """ #------------------------------------------------------------------------ # Data-related traits #------------------------------------------------------------------------ # The data source to use for the index coordinate. index = Instance(ArrayDataSource) # The data source to use as value points. value = Instance(AbstractDataSource) # Screen mapper for index data. index_mapper = Instance(AbstractMapper) # Screen mapper for value data value_mapper = Instance(AbstractMapper) # Convenience properties that correspond to either index_mapper or # value_mapper, depending on the orientation of the plot. # Corresponds to either **index_mapper** or **value_mapper**, depending on # the orientation of the plot. x_mapper = Property # Corresponds to either **value_mapper** or **index_mapper**, depending on # the orientation of the plot. y_mapper = Property # Convenience property for accessing the index data range. index_range = Property # Convenience property for accessing the value data range. value_range = Property # The type of hit-testing that is appropriate for this renderer. # # * 'line': Computes Euclidean distance to the line between the # nearest adjacent points. # * 'point': Checks for adjacency to a marker or point. hittest_type = Enum("point", "line") #------------------------------------------------------------------------ # Appearance-related traits #------------------------------------------------------------------------ # The orientation of the index axis. orientation = Enum("h", "v") # Overall alpha value of the image. 
Ranges from 0.0 for transparent to 1.0 alpha = Range(0.0, 1.0, 1.0) #------------------------------------------------------------------------ # Convenience readonly properties for common annotations #------------------------------------------------------------------------ # Read-only property for horizontal grid. hgrid = Property # Read-only property for vertical grid. vgrid = Property # Read-only property for x-axis. x_axis = Property # Read-only property for y-axis. y_axis = Property # Read-only property for labels. labels = Property #------------------------------------------------------------------------ # Other public traits #------------------------------------------------------------------------ # Does the plot use downsampling? # This is not used right now. It needs an implementation of robust, fast # downsampling, which does not exist yet. use_downsampling = Bool(False) # Does the plot use a spatial subdivision structure for fast hit-testing? # This makes data updates slower, but makes hit-tests extremely fast. use_subdivision = Bool(False) # Overrides the default background color trait in PlotComponent. bgcolor = "transparent" # This just turns on a simple drawing of the X and Y axes... not a long # term solution, but good for testing. # Defines the origin axis color, for testing. origin_axis_color = black_color_trait # Defines a the origin axis width, for testing. origin_axis_width = Float(1.0) # Defines the origin axis visibility, for testing. origin_axis_visible = Bool(False) #------------------------------------------------------------------------ # Private traits #------------------------------------------------------------------------ # Are the cache traits valid? If False, new ones need to be compute. _cache_valid = Bool(False) # Cached array of (x,y) data-space points; regardless of self.orientation, # these points are always stored as (index_pt, value_pt). _cached_data_pts = Array # Cached array of (x,y) screen-space points. 
_cached_screen_pts = Array # Does **_cached_screen_pts** contain the screen-space coordinates # of the points currently in **_cached_data_pts**? _screen_cache_valid = Bool(False) # Reference to a spatial subdivision acceleration structure. _subdivision = Any #------------------------------------------------------------------------ # Abstract methods that subclasses must implement #------------------------------------------------------------------------ def _render(self, gc, points): """ Abstract method for rendering points. Parameters ---------- gc : graphics context Target for drawing the points points : List of Nx2 arrays Screen-space points to render """ raise NotImplementedError def _gather_points(self): """ Abstract method to collect data points that are within the range of the plot, and cache them. """ raise NotImplementedError def _downsample(self): """ Abstract method that gives the renderer a chance to downsample in screen space. """ # By default, this just does a mapscreen and returns the result raise NotImplementedError #------------------------------------------------------------------------ # Concrete methods below #------------------------------------------------------------------------ def __init__(self, **kwtraits): # Handling the setting/initialization of these traits manually because # they should be initialized in a certain order. 
kwargs_tmp = {"trait_change_notify": False} for trait_name in ("index", "value", "index_mapper", "value_mapper"): if trait_name in kwtraits: kwargs_tmp[trait_name] = kwtraits.pop(trait_name) self.set(**kwargs_tmp) AbstractPlotRenderer.__init__(self, **kwtraits) if self.index is not None: self.index.on_trait_change(self._either_data_changed, "data_changed") self.index.on_trait_change(self._either_metadata_changed, "metadata_changed") if self.index_mapper: self.index_mapper.on_trait_change(self._mapper_updated_handler, "updated") if self.value is not None: self.value.on_trait_change(self._either_data_changed, "data_changed") self.value.on_trait_change(self._either_metadata_changed, "metadata_changed") if self.value_mapper: self.value_mapper.on_trait_change(self._mapper_updated_handler, "updated") # If we are not resizable, we will not get a bounds update upon layout, # so we have to manually update our mappers if self.resizable == "": self._update_mappers() return def hittest(self, screen_pt, threshold=7.0, return_distance=False): """ Performs proximity testing between a given screen point and the plot. Parameters ---------- screen_pt : (x,y) A point to test. threshold : integer Optional maximum screen space distance (pixels) between *screen_pt* and the plot. return_distance : Boolean If True, also return the distance. Returns ------- If self.hittest_type is 'point', then this method returns the screen coordinates of the closest point on the plot as a tuple (x,y) If self.hittest_type is 'line', then this method returns the screen endpoints of the line segment closest to *screen_pt*, as ((x1,y1), (x2,y2)) If *screen_pt* does not fall within *threshold* of the plot, then this method returns None. If return_distance is True, return the (x, y, d), where d is the distance between the distance between the input point and the closest point (x, y), in screen coordinates. 
""" if self.hittest_type == "point": tmp = self.get_closest_point(screen_pt, threshold) elif self.hittest_type == "line": tmp = self.get_closest_line(screen_pt, threshold) else: raise ValueError("Unknown hittest type '%s'" % self.hittest_type) if tmp is not None: if return_distance: return tmp else: return tmp[:-1] else: return None def get_closest_point(self, screen_pt, threshold=7.0): """ Tests for proximity in screen-space. This method checks only data points, not the line segments connecting them; to do the latter use get_closest_line() instead. Parameters ---------- screen_pt : (x,y) A point to test. threshold : integer Optional maximum screen space distance (pixels) between *screen_pt* and the plot. If 0.0, then no threshold tests are performed, and the nearest point is returned. Returns ------- (x, y, distance) of a datapoint nearest to *screen_pt*. If no data points are within *threshold* of *screen_pt*, returns None. """ ndx = self.map_index(screen_pt, threshold) if ndx is not None: x = self.x_mapper.map_screen(self.index.get_data()[ndx]) y = self.y_mapper.map_screen(self.value.get_data()[ndx]) return (x, y, sqrt((x - screen_pt[0])**2 + (y - screen_pt[1])**2)) else: return None def get_closest_line(self, screen_pt, threshold=7.0): """ Tests for proximity in screen-space against lines connecting the points in this plot's dataset. Parameters ---------- screen_pt : (x,y) A point to test. threshold : integer Optional maximum screen space distance (pixels) between the line and the plot. If 0.0, then the method returns the closest line regardless of distance from the plot. Returns ------- (x1, y1, x2, y2, dist) of the endpoints of the line segment closest to *screen_pt*. The *dist* element is the perpendicular distance from *screen_pt* to the line. If there is only a single point in the renderer's data, then the method returns the same point twice. If no data points are within *threshold* of *screen_pt*, returns None. 
""" ndx = self.map_index(screen_pt, threshold=0.0) if ndx is None: return None index_data = self.index.get_data() value_data = self.value.get_data() x = self.x_mapper.map_screen(index_data[ndx]) y = self.y_mapper.map_screen(value_data[ndx]) # We need to find another index so we have two points; in the # even that we only have 1 point, just return that point. datalen = len(index_data) if datalen == 1: dist = (x, y, sqrt((x - screen_pt[0])**2 + (y - screen_pt[1])**2)) if (threshold == 0.0) or (dist <= threshold): return (x, y, x, y, dist) else: return None else: if (ndx == 0) or (screen_pt[0] >= x): ndx2 = ndx + 1 elif (ndx == datalen - 1) or (screen_pt[0] <= x): ndx2 = ndx - 1 x2 = self.x_mapper.map_screen(index_data[ndx2]) y2 = self.y_mapper.map_screen(value_data[ndx2]) dist = point_line_distance(screen_pt, (x, y), (x2, y2)) if (threshold == 0.0) or (dist <= threshold): return (x, y, x2, y2, dist) else: return None #------------------------------------------------------------------------ # AbstractPlotRenderer interface #------------------------------------------------------------------------ def map_screen(self, data_array): """ Maps an array of data points into screen space and returns it as an array. Implements the AbstractPlotRenderer interface. """ # data_array is Nx2 array if len(data_array) == 0: return [] x_ary, y_ary = transpose(data_array) sx = self.index_mapper.map_screen(x_ary) sy = self.value_mapper.map_screen(y_ary) if self.orientation == "h": return transpose(array((sx, sy))) else: return transpose(array((sy, sx))) def map_data(self, screen_pt, all_values=False): """ Maps a screen space point into the "index" space of the plot. Implements the AbstractPlotRenderer interface. If *all_values* is True, returns an array of (index, value) tuples; otherwise, it returns only the index values. 
""" x, y = screen_pt if self.orientation == 'v': x, y = y, x if all_values: return array( (self.index_mapper.map_data(x), self.value_mapper.map_data(y))) else: return self.index_mapper.map_data(x) def map_index(self, screen_pt, threshold=2.0, outside_returns_none=True, index_only=False): """ Maps a screen space point to an index into the plot's index array(s). Implements the AbstractPlotRenderer interface. Parameters ---------- screen_pt : Screen space point threshold : float Maximum distance from screen space point to plot data point. A value of 0.0 means no threshold (any distance will do). outside_returns_none : bool If True, a screen space point outside the data range returns None. Otherwise, it returns either 0 (outside the lower range) or the last index (outside the upper range) index_only : bool If True, the threshold is measured on the distance between the index values, otherwise as Euclidean distance between the (x,y) coordinates. """ data_pt = self.map_data(screen_pt) if ((data_pt < self.index_mapper.range.low) or (data_pt > self.index_mapper.range.high)) and outside_returns_none: return None index_data = self.index.get_data() value_data = self.value.get_data() if len(value_data) == 0 or len(index_data) == 0: return None try: # find the closest point to data_pt in index_data ndx = reverse_map_1d(index_data, data_pt, self.index.sort_order) except IndexError: # if reverse_map raises this exception, it means that data_pt is # outside the range of values in index_data. if outside_returns_none: return None else: if data_pt < index_data[0]: return 0 else: return len(index_data) - 1 if threshold == 0.0: # Don't do any threshold testing return ndx x = index_data[ndx] y = value_data[ndx] if isnan(x) or isnan(y): return None # transform x,y in a 1x2 array, which is the preferred format of # map_screen. 
this makes it robust against differences in # the map_screen methods of logmapper and linearmapper # when passed a scalar xy = array([[x, y]]) sx, sy = self.map_screen(xy).T if index_only and (threshold == 0.0 or screen_pt[0] - sx < threshold): return ndx elif ((screen_pt[0] - sx)**2 + (screen_pt[1] - sy)**2 < threshold * threshold): return ndx else: return None def get_screen_points(self): """Returns the currently visible screen-space points. Intended for use with overlays. """ self._gather_points() if self.use_downsampling: # The BaseXYPlot implementation of _downsample doesn't actually # do any downsampling. return self._downsample() else: return self.map_screen(self._cached_data_pts) #------------------------------------------------------------------------ # PlotComponent interface #------------------------------------------------------------------------ def _draw_plot(self, gc, view_bounds=None, mode="normal"): """ Draws the 'plot' layer. """ self._draw_component(gc, view_bounds, mode) return def _draw_component(self, gc, view_bounds=None, mode="normal"): # This method should be folded into self._draw_plot(), but is here for # backwards compatibilty with non-draw-order stuff. 
pts = self.get_screen_points() self._render(gc, pts) return def _draw_default_axes(self, gc): if not self.origin_axis_visible: return with gc: gc.set_stroke_color(self.origin_axis_color_) gc.set_line_width(self.origin_axis_width) gc.set_line_dash(None) for range in (self.index_mapper.range, self.value_mapper.range): if (range.low < 0) and (range.high > 0): if range == self.index_mapper.range: dual = self.value_mapper.range data_pts = array([[0.0, dual.low], [0.0, dual.high]]) else: dual = self.index_mapper.range data_pts = array([[dual.low, 0.0], [dual.high, 0.0]]) start, end = self.map_screen(data_pts) start = around(start) end = around(end) gc.move_to(int(start[0]), int(start[1])) gc.line_to(int(end[0]), int(end[1])) gc.stroke_path() return def _post_load(self): super(BaseXYPlot, self)._post_load() self._update_mappers() self.invalidate_draw() self._cache_valid = False self._screen_cache_valid = False return def _update_subdivision(self): return #------------------------------------------------------------------------ # Properties #------------------------------------------------------------------------ def _get_index_range(self): return self.index_mapper.range def _set_index_range(self, val): self.index_mapper.range = val def _get_value_range(self): return self.value_mapper.range def _set_value_range(self, val): self.value_mapper.range = val def _get_x_mapper(self): if self.orientation == "h": return self.index_mapper else: return self.value_mapper def _get_y_mapper(self): if self.orientation == "h": return self.value_mapper else: return self.index_mapper def _get_hgrid(self): for obj in self.underlays + self.overlays: if isinstance(obj, PlotGrid) and obj.orientation == "horizontal": return obj else: return None def _get_vgrid(self): for obj in self.underlays + self.overlays: if isinstance(obj, PlotGrid) and obj.orientation == "vertical": return obj else: return None def _get_x_axis(self): for obj in self.underlays + self.overlays: if isinstance(obj, PlotAxis) 
and obj.orientation in ("bottom", "top"): return obj else: return None def _get_y_axis(self): for obj in self.underlays + self.overlays: if isinstance(obj, PlotAxis) and obj.orientation in ("left", "right"): return obj else: return None def _get_labels(self): labels = [] for obj in self.underlays + self.overlays: if isinstance(obj, PlotLabel): labels.append(obj) return labels #------------------------------------------------------------------------ # Event handlers #------------------------------------------------------------------------ def _update_mappers(self): x_mapper = self.index_mapper y_mapper = self.value_mapper if self.orientation == "v": x_mapper, y_mapper = y_mapper, x_mapper x = self.x x2 = self.x2 y = self.y y2 = self.y2 if "left" in self.origin: x_mapper.screen_bounds = (x, x2) else: x_mapper.screen_bounds = (x2, x) if "bottom" in self.origin: y_mapper.screen_bounds = (y, y2) else: y_mapper.screen_bounds = (y2, y) self.invalidate_draw() self._cache_valid = False self._screen_cache_valid = False def _bounds_changed(self, old, new): super(BaseXYPlot, self)._bounds_changed(old, new) self._update_mappers() def _bounds_items_changed(self, event): super(BaseXYPlot, self)._bounds_items_changed(event) self._update_mappers() def _position_changed(self): self._update_mappers() def _position_items_changed(self): self._update_mappers() def _orientation_changed(self): self._update_mappers() def _index_changed(self, old, new): if old is not None: old.on_trait_change(self._either_data_changed, "data_changed", remove=True) old.on_trait_change(self._either_metadata_changed, "metadata_changed", remove=True) if new is not None: new.on_trait_change(self._either_data_changed, "data_changed") new.on_trait_change(self._either_metadata_changed, "metadata_changed") self._either_data_changed() return def _either_data_changed(self): self.invalidate_draw() self._cache_valid = False self._screen_cache_valid = False self.request_redraw() return def _either_metadata_changed(self): 
# By default, don't respond to metadata change events. pass def _value_changed(self, old, new): if old is not None: old.on_trait_change(self._either_data_changed, "data_changed", remove=True) old.on_trait_change(self._either_metadata_changed, "metadata_changed", remove=True) if new is not None: new.on_trait_change(self._either_data_changed, "data_changed") new.on_trait_change(self._either_metadata_changed, "metadata_changed") self._either_data_changed() return def _origin_changed(self, old, new): # origin switch from left to right or vice versa? if old.split()[1] != new.split()[1]: xm = self.x_mapper xm.low_pos, xm.high_pos = xm.high_pos, xm.low_pos # origin switch from top to bottom or vice versa? if old.split()[0] != new.split()[0]: ym = self.y_mapper ym.low_pos, ym.high_pos = ym.high_pos, ym.low_pos self.invalidate_draw() self._screen_cache_valid = False return def _index_mapper_changed(self, old, new): self._either_mapper_changed(self, "index_mapper", old, new) if self.orientation == "h": self.trait_property_changed("x_mapper", old, new) else: self.trait_property_changed("y_mapper", old, new) return def _value_mapper_changed(self, old, new): self._either_mapper_changed(self, "value_mapper", old, new) if self.orientation == "h": self.trait_property_changed("y_mapper", old, new) else: self.trait_property_changed("x_mapper", old, new) return def _either_mapper_changed(self, obj, name, old, new): if old is not None: old.on_trait_change(self._mapper_updated_handler, "updated", remove=True) if new is not None: new.on_trait_change(self._mapper_updated_handler, "updated") self.invalidate_draw() self._screen_cache_valid = False return def _mapper_updated_handler(self): self._cache_valid = False self._screen_cache_valid = False self.invalidate_draw() self.request_redraw() return def _visible_changed(self, old, new): if new: self._layout_needed = True def _bgcolor_changed(self): self.invalidate_draw() def _use_subdivision_changed(self, old, new): if new: 
self._set_up_subdivision() return #------------------------------------------------------------------------ # Persistence #------------------------------------------------------------------------ def __getstate__(self): state = super(BaseXYPlot, self).__getstate__() for key in [ '_cache_valid', '_cached_data_pts', '_screen_cache_valid', '_cached_screen_pts' ]: if key in state: del state[key] return state def __setstate__(self, state): super(BaseXYPlot, self).__setstate__(state) if self.index is not None: self.index.on_trait_change(self._either_data_changed, "data_changed") if self.value is not None: self.value.on_trait_change(self._either_data_changed, "data_changed") self.invalidate_draw() self._cache_valid = False self._screen_cache_valid = False self._update_mappers() return
class DataFramePlotManagerExporter(HasStrictTraits):
    """ Exporter of a DataFramePlotManager content to various formats.

    Depending on :attr:`export_format`, the plots contained in
    :attr:`df_plotter` are exported as a folder of image files
    (:meth:`to_folder`), as a PowerPoint presentation (:meth:`to_pptx`) or
    as a Vega-Lite style JSON description (:meth:`to_vega`). The data behind
    the plots can optionally be exported too (see :attr:`export_data`).
    """
    #: Plot manager whose contained plots are exported
    df_plotter = Instance("pybleau.app.api.DataFramePlotManager")

    #: Target format; selects the export method through METHOD_MAP
    export_format = Enum([IMG_FORMAT, PPT_FORMAT, VEGA_FORMAT])

    #: Whether to, and how to, export the data behind the plots
    export_data = Enum(values="_export_data_options")

    #: With EXPORT_YES, whether to also export each plot's own data (True)
    #: or only the plotter's data source (False)
    export_each_plot_data = Bool

    #: The ways to export the data depend on the target format
    _export_data_options = Property(List(Str), depends_on="export_format")

    #: Base name (without extension) of the exported data file
    data_filename = Str(EXTERNAL_DATA_FNAME)

    #: File format (extension) of the exported data
    data_format = Enum([".csv", ".xlsx", ".h5"])

    #: Whether to skip plots whose visible flag is off
    skip_hidden = Bool(True)

    #: Target directory (used by the image export)
    target_dir = Directory

    #: Target file (used by the PowerPoint and Vega exports)
    target_file = File

    #: View class used to build the parameter dialog
    view_klass = Any(View)

    #: Whether to show dialogs (parameters, errors, information) and open
    #: the exported result
    interactive = Bool(True)

    # Image format specific parameters ----------------------------------------

    image_format = Enum(["PNG", "JPG", "BMP", "TIFF", "EPS"])

    image_dpi = Range(low=50, high=100, value=72)

    filename_from_title = Bool(True)

    # General output parameters -----------------------------------------------

    overwrite_file_if_exists = Bool(False)

    #: Indentation used when writing the Vega JSON file
    json_index = Range(low=0, high=4, value=2)

    #: Whether the plotter contains enough plots (3 or more) to warn about
    #: inline data export
    _many_plots = Bool

    # PPT format specific parameters ------------------------------------------

    presentation_title = Str("New Presentation")

    presentation_subtitle = Str

    def traits_view(self):
        """ Build the export parameter dialog.

        Groups are shown or hidden depending on the selected export format
        and data export option.
        """
        is_ppt = "export_format == '{}'".format(PPT_FORMAT)
        is_vega = "export_format == '{}'".format(VEGA_FORMAT)
        is_img = "export_format=='{}'".format(IMG_FORMAT)
        num_plots = len(self.df_plotter.contained_plots)
        inline_warning = "Warning: exporting data inline can use a lot of " \
                         "file storage because the same data is stored inside"\
                         " each of the {} plots.".format(num_plots)
        inline_visible_when = "_many_plots and export_data=='{}'".format(
            EXPORT_INLINE)
        dat_file_visible_when = "export_data in ['{}', '{}']".format(
            EXPORT_YES, EXPORT_SEPARATE)

        view = self.view_klass(
            VGroup(
                HGroup(Spring(), Item("export_format"), Spring()),
                VGroup(Item('target_dir', visible_when=is_img),
                       HGroup(Item('target_file',
                                   editor=FileEditor(dialog_style='save')),
                              Item("overwrite_file_if_exists",
                                   label="Overwrite report file if exists?"),
                              visible_when=is_ppt + " or " + is_vega),
                       Item("overwrite_file_if_exists",
                            label="Overwrite image files if exist?",
                            visible_when=is_img),
                       Item('skip_hidden', label="Skip hidden plots"),
                       label="General Parameters", show_border=True),
                VGroup(
                    HGroup(
                        Item('export_data', label="Export data?"),
                        Item("export_each_plot_data",
                             label="Export each plot's data?",
                             visible_when="export_data=='{}'".format(
                                 EXPORT_YES))  # noqa
                    ),
                    HGroup(Item("data_filename"),
                           Item("data_format"),
                           visible_when=dat_file_visible_when),
                    HGroup(Label(inline_warning),
                           visible_when=inline_visible_when),
                    label="Data Parameters", show_border=True),
                VGroup(HGroup(
                    Item('image_format'),
                    Item('image_dpi'),
                ),
                    Item("filename_from_title"),
                    visible_when=is_img,
                    label="Image Parameters", show_border=True),
                VGroup(Item("presentation_title"),
                       Item("presentation_subtitle"),
                       label="Powerpoint Parameters", show_border=True,
                       visible_when=is_ppt),
            ),
            buttons=OKCancelButtons,
            title="Export plots",
            resizable=True,
            width=600
        )
        return view

    # Public interface --------------------------------------------------------

    def export(self):
        """ Launch view to select parameters and export content to file.

        In non-interactive mode the export runs directly with the current
        parameters. Errors raised by the export are logged (and shown in an
        error dialog in interactive mode) rather than propagated.
        """
        if self.interactive:
            ui = self.edit_traits(kind="livemodal")

        if not self.interactive or ui.result:
            msg = f"Exporting plot content to {self.export_format}."
            logger.log(ACTION_LEVEL, msg)
            to_meth = getattr(self, METHOD_MAP[self.export_format])
            try:
                to_meth()
                if self.interactive:
                    target = self.target_dir if self.export_format != \
                        PPT_FORMAT else self.target_file
                    open_file(target)
            except Exception as e:
                msg = f"Failed to export the plot list. Error was {e}."
                logger.exception(msg)
                if self.interactive:
                    error(None, msg)

    def to_folder(self, **kwargs):
        """ Export all plots as separate images files PNG, JPEG, ....

        Parameters
        ----------
        kwargs : dict
            Additional keywords passed to save_plot_to_file.

        Raises
        ------
        IOError
            If a target image file exists and overwrite_file_if_exists is off.
        """
        if not isdir(self.target_dir):
            os.makedirs(self.target_dir)

        plot_list = self.df_plotter.contained_plots
        if self.export_data == EXPORT_YES:
            if not self.export_each_plot_data:
                self._export_data_source_to_file(target=self.target_dir)
            else:
                fname = self.data_filename + self.data_format
                data_path = join(self.target_dir, fname)
                self._export_plot_data_to_file(plot_list, data_path=data_path)

        if self.filename_from_title:
            filename_patt = "{i}_{title}.{ext}"
        else:
            filename_patt = "plot_{i}.{ext}"

        for i, desc in enumerate(plot_list):
            if self.skip_hidden and not desc.visible:
                continue

            if self.filename_from_title:
                title = string2filename(desc.plot_title)
            else:
                title = ""

            filename = filename_patt.format(i=i, ext=self.image_format,
                                            title=title)
            filepath = join(self.target_dir, filename)
            if isfile(filepath) and not self.overwrite_file_if_exists:
                msg = "Target image file path specified already exists" \
                      ": {}. Move the file or select the 'overwrite' checkbox."
                msg = msg.format(filepath)
                # Not inside an except clause: log without a traceback.
                logger.error(msg)
                raise IOError(msg)

            save_plot_to_file(desc.plot, filepath=filepath,
                              dpi=self.image_dpi, **kwargs)

    def to_pptx(self, **kwargs):
        """ Export all plots as a PPTX presentation with a plot per slide.

        Parameters
        ----------
        kwargs : dict
            Additional keywords passed to save_plot_to_file.

        Raises
        ------
        IOError
            If the target file exists and overwrite_file_if_exists is off.
        """
        # Protect imports so pptx remains an optional import
        from pybleau.reporting.pptx_utils import image_to_slide, \
            Presentation, title_slide

        target_dir = dirname(self.target_file)
        data_fname = self.data_filename + self.data_format
        target_data_file = join(target_dir, data_fname)
        if not isdir(target_dir):
            os.makedirs(target_dir)

        plot_list = self.df_plotter.contained_plots
        if self.export_data == EXPORT_YES:
            if not self.export_each_plot_data:
                self._export_data_source_to_file(target=target_data_file)
            else:
                self._export_plot_data_to_file(plot_list,
                                               data_path=target_data_file)

        if splitext(self.target_file)[1] != ".pptx":
            self.target_file += ".pptx"

        if isfile(self.target_file) and not self.overwrite_file_if_exists:
            msg = "Target description file path specified already exists" \
                  ": {}. Move the file or select the 'overwrite' checkbox."
            msg = msg.format(self.target_file)
            logger.error(msg)
            raise IOError(msg)

        presentation = Presentation()
        title_slide(presentation, title_text=self.presentation_title,
                    sub_title_text=self.presentation_subtitle)

        # mkstemp creates *and opens* a file: close its descriptor right away
        # and use the .png-suffixed path it created for every slide image, so
        # no descriptor or stray temporary file is left behind.
        fd, img_path = mkstemp(suffix=".png")
        os.close(fd)
        try:
            for i, desc in enumerate(plot_list):
                if self.skip_hidden and not desc.visible:
                    continue

                save_plot_to_file(desc.plot, filepath=img_path,
                                  dpi=self.image_dpi, **kwargs)
                title = "plot_{}".format(i)
                image_to_slide(presentation, img_path=img_path,
                               slide_title=title)
                os.remove(img_path)
        finally:
            # Covers the untouched placeholder (all plots skipped) and
            # failures between saving and removing the image.
            if isfile(img_path):
                os.remove(img_path)

        presentation.save(self.target_file)

    def to_vega(self):
        """ Export plot content to dict (option. file) in Vega-Lite format.

        This is useful to recreate this plotter's content, in a separate
        process, using any plotting library. See the pybleau.reporting for
        tools leveraging this export.

        Returns
        -------
        dict:
            Description of the content of the data plotter.

        Raises
        ------
        IOError
            If the target file exists and overwrite_file_if_exists is off.

        TODO: Add filter information for each plot.
        """
        content = {CONTENT_KEY: [], DATASETS_KEY: {}}

        if splitext(self.target_file)[1] != ".json":
            self.target_file += ".json"

        export_dir = abspath(dirname(self.target_file))
        if not isdir(export_dir):
            os.makedirs(export_dir)

        if self.export_data != EXPORT_NO:
            df_data = {}
            if self.export_data == EXPORT_SEPARATE:
                df_data[DATA_FILE_KEY] = self.data_filename + self.data_format
                if self.data_format == ".h5":
                    # NOTE(review): the JSON records DATA_KEY as the HDF5 key
                    # while the data source is written under the default key
                    # (DEFAULT_DATASET_NAME); confirm the two are equal.
                    df_data[DATA_FILE_KEY_KEY] = DATA_KEY
                    comp_info = dict(complib="blosc", complevel=9)
                    df_data[DATA_FILE_COMP_KEY] = comp_info
                    self._export_data_source_to_file(export_dir, **comp_info)
                else:
                    self._export_data_source_to_file(export_dir)
            elif self.export_data == EXPORT_IN_FILE:
                df_data[DATA_KEY] = df_to_vega(self.df_plotter.data_source)
                df_data[IDX_NAME_KEY] = self.df_plotter.data_source.index.name

            content[DATASETS_KEY][DEFAULT_DATASET_NAME] = df_data

        # NOTE(review): unlike to_folder/to_pptx, skip_hidden is not applied
        # here; confirm whether hidden plots should be exported to Vega.
        for desc in self.df_plotter.contained_plots:
            if self.export_data == EXPORT_INLINE:
                plot_desc = chaco2vega(desc.plot_config, export_data="inline")
            elif self.export_data == EXPORT_IN_FILE:
                plot_desc = chaco2vega(desc.plot_config,
                                       export_data=DEFAULT_DATASET_NAME)
            else:
                plot_desc = chaco2vega(desc.plot_config, export_data=False)
            content[CONTENT_KEY].append(plot_desc)

        if self.target_file:
            if isfile(self.target_file) and not self.overwrite_file_if_exists:
                msg = f"Target description file path specified already " \
                      f"exists: {self.target_file}. Move the file or select " \
                      f"the 'overwrite' checkbox."
                logger.error(msg)
                raise IOError(msg)

            # Context manager guarantees the JSON file is flushed and closed.
            with open(self.target_file, "w") as f:
                json.dump(content, f, indent=self.json_index)

        return content

    # Private interface -------------------------------------------------------

    def _export_data_source_to_file(self, target, key=DEFAULT_DATASET_NAME,
                                    **kwargs):
        """ Export the plotter's data source to file.

        Parameters
        ----------
        target : str or pd.ExcelWriter or pd.HDFStore
            Path to the file or folder or store object to write the source
            data to. When a folder, the file is named
            ``data_filename + data_format`` inside it.

        key : str
            Tag for the dataset. Used only for Excel files (sheet name) or
            HDF5 (dataset node name).

        kwargs : dict
            Additional keywords passed to DataFrame.to_excel/to_hdf (not used
            for CSV).

        Returns
        -------
        str or pd.ExcelWriter or pd.HDFStore
            The target actually written to.

        Raises
        ------
        NotImplementedError
            If data_format is not supported.
        """
        data_format = self.data_format
        if isinstance(target, string_types) and isfile(target):
            msg = "Target data file path specified already exists: {}. It " \
                  "will be overwritten!".format(target)
            logger.info(msg)
        elif isinstance(target, string_types) and isdir(target):
            target = join(target, self.data_filename + data_format)

        df = self.df_plotter.data_source
        if data_format == ".csv":
            df.to_csv(target)
        elif data_format == ".xlsx":
            df.to_excel(target, sheet_name=key, **kwargs)
        elif data_format == ".h5":
            df.to_hdf(target, key=key, **kwargs)
        else:
            msg = "Format {} not implemented. Please report this issue."
            msg = msg.format(data_format)
            logger.error(msg)
            raise NotImplementedError(msg)
        return target

    def _export_plot_data_to_file(self, plot_list, data_path, **kwargs):
        """ Export the plots' PlotData to a file.

        Supported formats include zipped .csv, multi-tab .xlsx and multi-key
        HDF5. The plotter's data source is written first, under the
        DEFAULT_DATASET_NAME key/sheet; each plot's datasets then use the key
        ``plot_<index>_<name>``.

        Parameters
        ----------
        plot_list : list
            List of PlotDescriptor instances, containing the plot whose data
            need to be exported.

        data_path : str
            Path to the data file to be generated. The data_format extension
            is appended when the path has none.

        kwargs : dict
            Additional keywords passed to DataFrame.to_excel/to_hdf for each
            plot's data (not used for CSV).
        """
        data_format = self.data_format
        if not splitext(data_path)[1]:
            data_path += data_format

        if isfile(data_path):
            msg = "Target data path specified already exists: {}. It will be" \
                  " overwritten.".format(data_path)
            logger.warning(msg)

        data_dir = dirname(data_path)

        # One writer/store for the whole export, closed in the finally clause.
        # Every DataFrame is written through it (creating a second
        # ExcelWriter, or re-opening the HDF5 file by path while the store
        # holds it open, leaks a handle or conflicts with the open file).
        writer = None
        if data_format == ".xlsx":
            writer = pd.ExcelWriter(data_path)
        elif data_format == ".h5":
            writer = pd.HDFStore(data_path)

        created_csv_files = []
        try:
            if data_format == ".csv":
                data_path = join(
                    data_dir, string2filename(DEFAULT_DATASET_NAME) + ".csv")
                self._export_data_source_to_file(target=data_path)
            else:
                self._export_data_source_to_file(target=writer,
                                                 key=DEFAULT_DATASET_NAME)

            created_csv_files.append(data_path)
            for i, desc in enumerate(plot_list):
                df_dict = plot_data2dataframes(desc)
                for name, df in df_dict.items():
                    key = "plot_{}_{}".format(i, name)
                    if data_format == ".csv":
                        target_fpath = join(data_dir,
                                            string2filename(key) + ".csv")
                        df.to_csv(target_fpath)
                        created_csv_files.append(target_fpath)
                    elif data_format == ".xlsx":
                        df.to_excel(writer, sheet_name=key, **kwargs)
                    elif data_format == ".h5":
                        df.to_hdf(writer, key=key, **kwargs)
                    else:
                        msg = "Data format {} not implemented. Please report" \
                              " this issue.".format(data_format)
                        logger.error(msg)
                        raise NotImplementedError(msg)
        finally:
            if writer is not None:
                writer.close()

        if data_format == ".csv" and len(created_csv_files) > 1:
            # zip up all csv files:
            data_path = join(data_dir, self.data_filename + ".zip")
            with ZipFile(data_path, "w") as f:
                for f_path in created_csv_files:
                    f.write(f_path, basename(f_path))

            for f_path in created_csv_files:
                os.remove(f_path)

        if self.interactive:
            msg = "Plot data is stored in: {}".format(data_path)
            information(None, msg)

    # Property getters/setters ------------------------------------------------

    @cached_property
    def _get__export_data_options(self):
        """ Data export choices available for the current export format.
        """
        if self.export_format in {IMG_FORMAT, PPT_FORMAT}:
            return [EXPORT_NO, EXPORT_YES]
        elif self.export_format == VEGA_FORMAT:
            return [EXPORT_NO, EXPORT_SEPARATE, EXPORT_IN_FILE, EXPORT_INLINE]
        else:
            msg = "List of export data options not set for format {}."
            msg = msg.format(self.export_format)
            logger.error(msg)
            raise NotImplementedError(msg)

    # Traits initialization methods -------------------------------------------

    def _target_file_default(self):
        return join(self.target_dir, DEFAULT_EXPORT_FILENAME)

    def _target_dir_default(self):
        return expanduser("~")

    def __many_plots_default(self):
        return len(self.df_plotter.contained_plots) >= 3
class ToolkitEditorFactory(EditorFactory):
    """ Editor factory for range editors.
    """

    # -------------------------------------------------------------------------
    #  Trait definitions:
    # -------------------------------------------------------------------------

    #: Number of columns when displayed as an enumeration
    cols = Range(1, 20)

    #: Is user input set on every keystroke?
    auto_set = Bool(True)

    #: Is user input set when the Enter key is pressed?
    enter_set = Bool(False)

    #: Label for the low end of the range
    low_label = Str()

    #: Label for the high end of the range
    high_label = Str()

    #: FIXME: This is supported only in the wx backend so far.
    #: The width of the low and high labels
    label_width = Int()

    #: The name of an [object.]trait that defines the low value for the range
    low_name = Str()

    #: The name of an [object.]trait that defines the high value for the range
    high_name = Str()

    #: Formatting string used to format value and labels
    format = Str("%s")

    #: Is the range for floating point numbers (vs. integers)?
    is_float = Bool(Undefined)

    #: Function to evaluate floats/ints when they are assigned to an object
    #: trait
    evaluate = Any()

    #: The object trait containing the function used to evaluate floats/ints
    evaluate_name = Str()

    #: Low end of range
    low = Property()

    #: High end of range
    high = Property()

    #: Display mode to use
    mode = Enum(
        "auto", "slider", "xslider", "spinner", "enum", "text", "logslider"
    )

    # -------------------------------------------------------------------------
    #  Traits view definition:
    # -------------------------------------------------------------------------

    traits_view = View(
        [
            ["low", "high", "|[Range]"],
            ["low_label{Low}", "high_label{High}", "|[Range Labels]"],
            [
                "auto_set{Set automatically}",
                "enter_set{Set on enter key pressed}",
                "is_float{Is floating point range}",
                "-[Options]>",
            ],
            ["cols", "|[Number of columns for integer custom style]<>"],
        ]
    )

    def init(self, handler=None):
        """ Performs any initialization needed after all constructor traits
            have been set.

        When a trait handler (or a CTrait wrapping one) is given, each limit
        not bound to a named trait (``low_name`` / ``high_name`` empty) is
        taken from the handler's ``_low`` / ``_high`` value, evaluating it
        first when it is a compiled code object. Without a handler, unset
        and unbound limits default to 0.0 and 1.0.
        """
        if handler is not None:
            if isinstance(handler, CTrait):
                handler = handler.handler

            if self.low_name == "":
                if isinstance(handler._low, CodeType):
                    self.low = eval(handler._low)
                else:
                    self.low = handler._low

            if self.high_name == "":
                # Inspect _high itself: whether the *low* limit is compiled
                # says nothing about the high one.
                if isinstance(handler._high, CodeType):
                    self.high = eval(handler._high)
                else:
                    self.high = handler._high
        else:
            if (self.low is None) and (self.low_name == ""):
                self.low = 0.0

            if (self.high is None) and (self.high_name == ""):
                self.high = 1.0

    def _get_low(self):
        return self._low

    def _set_low(self, low):
        """ Store the (cast) low value, inferring is_float on first use and
            refreshing low_label if it still shows the previous value.
        """
        old_low = self._low
        self._low = low = self._cast(low)
        if self.is_float is Undefined:
            self.is_float = isinstance(low, float)

        if (self.low_label == "") or (self.low_label == str(old_low)):
            self.low_label = str(low)

    def _get_high(self):
        return self._high

    def _set_high(self, high):
        """ Store the (cast) high value, inferring is_float on first use and
            refreshing high_label if it still shows the previous value.
        """
        old_high = self._high
        self._high = high = self._cast(high)
        if self.is_float is Undefined:
            self.is_float = isinstance(high, float)

        if (self.high_label == "") or (self.high_label == str(old_high)):
            self.high_label = str(high)

    def _cast(self, value):
        """ Convert a string to int (or float if not an int literal); other
            values are returned unchanged.
        """
        if not isinstance(value, str):
            return value

        try:
            return int(value)
        except ValueError:
            return float(value)

    # -- Private Methods ------------------------------------------------------

    def _get_low_high(self, ui):
        """ Returns the low and high values used to determine the initial
            range, as a ``(low, high, is_float)`` tuple.

        Limits that are unset but bound to a named trait are looked up
        through ``named_value``. is_float defaults to True when it cannot be
        inferred from either limit.
        """
        low, high = self.low, self.high

        if (low is None) and (self.low_name != ""):
            low = self.named_value(self.low_name, ui)
            if self.is_float is Undefined:
                self.is_float = isinstance(low, float)

        if (high is None) and (self.high_name != ""):
            high = self.named_value(self.high_name, ui)
            if self.is_float is Undefined:
                self.is_float = isinstance(high, float)

        if self.is_float is Undefined:
            self.is_float = True

        return (low, high, self.is_float)

    # -------------------------------------------------------------------------
    #  Property getters.
    # -------------------------------------------------------------------------

    def _get_simple_editor_class(self):
        """ Returns the editor class to use for a simple style.

        The type of editor depends on the type and extent of the range being
        edited:

        * One end of range is unspecified: RangeTextEditor
        * Integer range with extent > 1000000000: RangeTextEditor
        * **mode** is specified and not 'auto': editor corresponding to
          **mode**
        * Floating point range with extent > 100: LargeRangeSliderEditor
        * Integer range or floating point range with extent <= 100:
          SimpleSliderEditor
        * All other cases (integer range with extent > 100):
          SimpleSpinEditor
        """
        low, high, is_float = self._low_value, self._high_value, self.is_float

        if (low is None) or (high is None):
            return toolkit_object("range_editor:RangeTextEditor")

        if (not is_float) and (abs(high - low) > 1000000000):
            return toolkit_object("range_editor:RangeTextEditor")

        if self.mode != "auto":
            return toolkit_object("range_editor:SimpleEditorMap")[self.mode]

        if is_float and (abs(high - low) > 100):
            return toolkit_object("range_editor:LargeRangeSliderEditor")

        if is_float or (abs(high - low) <= 100):
            return toolkit_object("range_editor:SimpleSliderEditor")

        return toolkit_object("range_editor:SimpleSpinEditor")

    def _get_custom_editor_class(self):
        """ Returns the editor class to use for a custom style.

        The type of editor depends on the type and extent of the range being
        edited:

        * One end of range is unspecified: RangeTextEditor
        * **mode** is specified and not 'auto': editor corresponding to
          **mode**
        * Floating point range: Same as "simple" style
        * Integer range with extent > 15: Same as "simple" style
        * Integer range with extent <= 15: CustomEnumEditor
        """
        low, high, is_float = self._low_value, self._high_value, self.is_float
        if (low is None) or (high is None):
            return toolkit_object("range_editor:RangeTextEditor")

        if self.mode != "auto":
            return toolkit_object("range_editor:CustomEditorMap")[self.mode]

        if is_float or (abs(high - low) > 15):
            return self.simple_editor_class

        return toolkit_object("range_editor:CustomEnumEditor")

    def _get_text_editor_class(self):
        """Returns the editor class to use for a text style.
        """
        return toolkit_object("range_editor:RangeTextEditor")

    # -------------------------------------------------------------------------
    #  'Editor' factory methods:
    # -------------------------------------------------------------------------

    def simple_editor(self, ui, object, name, description, parent):
        """ Generates an editor using the "simple" style.
        Overridden to set the values of the _low_value, _high_value and
        is_float traits.

        """
        self._low_value, self._high_value, self.is_float = self._get_low_high(
            ui)
        # Name this class explicitly rather than relying on a module alias.
        return super(ToolkitEditorFactory, self).simple_editor(
            ui, object, name, description, parent)

    def custom_editor(self, ui, object, name, description, parent):
        """ Generates an editor using the "custom" style.
        Overridden to set the values of the _low_value, _high_value and
        is_float traits.

        """
        self._low_value, self._high_value, self.is_float = self._get_low_high(
            ui)
        return super(ToolkitEditorFactory, self).custom_editor(
            ui, object, name, description, parent)
class XYScatterOptions(BasePlotterOptions):
    """ Options for an X-Y scatter plot of two attributes.

    The X (index) and Y (value) attributes are chosen from :attr:`attrs`,
    which holds either DATABASE_ATTRS or the columns of a CSV file
    (when :attr:`datasource` is 'File').
    """
    # update_needed = Event
    auto_refresh = Bool

    index_attr = Str('Ar40')
    value_attr = Str('Ar39')

    marker_size = Range(0.0, 10., 2.0)
    marker = MarkerTrait
    marker_color = ColorTrait

    #: Selectable attributes (used as the EnumEditor values for the X/Y
    #: attribute items)
    attrs = Dict(DATABASE_ATTRS)

    index_error = Bool
    value_error = Bool
    index_end_caps = Bool
    value_end_caps = Bool
    index_nsigma = Enum(1, 2, 3)
    value_nsigma = Enum(1, 2, 3)

    index_time_units = Enum('h', 'm', 's', 'days')
    index_time_scalar = Property
    value_time_units = Enum('h', 'm', 's', 'days')
    value_time_scalar = Property

    fit = Enum([NULL_STR] + FIT_TYPES)

    datasource = Enum('Database', 'File')
    file_source_path = Str
    datasource_name = Property(depends_on='file_source_path')
    use_file_source = Property(depends_on='datasource')
    open_file_button = Button

    #: Lazily created CSVColumnParser for file_source_path
    _parser = None

    def get_marker_dict(self):
        """ Return the marker options as keyword arguments (marker,
            marker_size, color).
        """
        kw = dict(marker=self.marker,
                  marker_size=self.marker_size,
                  color=self.marker_color)
        return kw

    def get_parser(self):
        """ Return a CSVColumnParser loaded from file_source_path.

        The parser is created on first use and cached until
        file_source_path changes.
        """
        p = self._parser
        if p is None:
            p = CSVColumnParser()
            p.load(self.file_source_path)
            self._parser = p
        return p

    def _file_source_path_changed(self):
        # Drop the cached parser so the next get_parser() call loads the new
        # file instead of returning columns of the previous one.
        self._parser = None

    def _load_hook(self):
        if self.use_file_source:
            self._load_file_source()

    def _load_file_source(self):
        """ Populate attrs from the file's columns and select the first two
            columns as X and Y attributes.
        """
        p = self.get_parser()
        keys = p.list_attributes()
        self.attrs = {
            ai: '{:02d}:{}'.format(i, ai)
            for i, ai in enumerate(keys)
        }
        self.index_attr = keys[0]
        self.value_attr = keys[1]

    def _datasource_changed(self, new):
        if new == 'Database':
            self.attrs = DATABASE_ATTRS

    def _get_dump_attrs(self):
        return [
            'index_attr', 'index_error', 'index_end_caps', 'index_nsigma',
            'index_time_units', 'value_attr', 'value_error', 'value_end_caps',
            'value_nsigma', 'value_time_units', 'fit', 'marker_color', 'marker',
            'marker_size', 'auto_refresh', 'datasource', 'file_source_path'
        ]

    def _get_index_time_scalar(self):
        return TIME_SCALARS[self.index_time_units]

    def _get_value_time_scalar(self):
        return TIME_SCALARS[self.value_time_units]

    def _get_use_file_source(self):
        return self.datasource == 'File'

    def _get_datasource_name(self):
        p = ''
        if os.path.isfile(self.file_source_path):
            p = os.path.basename(self.file_source_path)
        return p

    def _open_file_button_fired(self):
        # TODO: restore the file dialog below; the hard-coded path is a
        # development fixture.
        # dlg=FileDialog(action='open', default_directory=paths.data_dir)
        # dlg.open()
        # if dlg.path:
        #     self.file_source_path=dlg.path
        #     self._load_file_source()
        p = '/Users/ross/Sandbox/xy_scatter_test.csv'
        self.file_source_path = p
        self._load_file_source()

    @on_trait_change('index_+, value_+, marker+')
    def _refresh(self):
        if self.auto_refresh:
            # self.update_needed = True
            self.refresh_plot_needed = True

    def traits_view(self):
        index_grp = VGroup(
            HGroup(
                Item('index_attr', editor=EnumEditor(name='attrs'),
                     label='X Attr.'),
                Item('index_time_units', label='Units',
                     visible_when='index_attr=="timestamp"')),
            HGroup(
                Item('index_error', label='Show'),
                Item('index_end_caps', label='End Caps',
                     enabled_when='index_error'),
                Item('index_nsigma', label='NSigma',
                     enabled_when='index_error')),
            label='X Error', show_border=True)

        value_grp = VGroup(
            HGroup(
                Item('value_attr', editor=EnumEditor(name='attrs'),
                     label='Y Attr.'),
                Item('value_time_units', label='Units',
                     visible_when='value_attr=="timestamp"')),
            HGroup(
                Item('value_error', label='Show'),
                Item('value_end_caps', label='End Caps',
                     enabled_when='value_error'),
                Item('value_nsigma', label='NSigma',
                     enabled_when='value_error')),
            label='Y Error', show_border=True)

        marker_grp = VGroup(Item('marker'),
                            Item('marker_size', label='Size'),
                            Item('marker_color', label='Color'),
                            show_border=True)

        # visible_when must be a valid expression: no stray quote.
        datasource_grp = HGroup(
            Item('datasource'),
            UItem('datasource_name', style='readonly',
                  visible_when='use_file_source'),
            icon_button_editor('open_file_button', 'document-open',
                               visible_when='use_file_source'))

        v = View(HGroup(Item('auto_refresh'),
                        icon_button_editor('refresh_plot_needed', 'refresh')),
                 datasource_grp,
                 index_grp,
                 value_grp,
                 marker_grp,
                 Item('fit'),
                 resizable=True)
        return v
class ImageButton(Widget):
    """ An image and text-based control that can be used as a normal, radio
        or toolbar button.
    """

    # Pens used to draw the 'selection' marker:
    _selectedPenDark = wx.Pen(
        wx.SystemSettings_GetColour(wx.SYS_COLOUR_3DSHADOW), 1, wx.SOLID)

    _selectedPenLight = wx.Pen(
        wx.SystemSettings_GetColour(wx.SYS_COLOUR_3DHIGHLIGHT), 1, wx.SOLID)

    #---------------------------------------------------------------------------
    #  Trait definitions:
    #---------------------------------------------------------------------------

    # The image:
    image = Instance(ImageResource, allow_none=True)

    # The (optional) label:
    label = Str

    # Extra padding to add to both the left and right sides:
    width_padding = Range(0, 31, 7)

    # Extra padding to add to both the top and bottom sides:
    height_padding = Range(0, 31, 5)

    # Presentation style:
    style = Enum('button', 'radio', 'toolbar', 'checkbox')

    # Orientation of the text relative to the image:
    orientation = Enum('vertical', 'horizontal')

    # Is the control selected ('radio' or 'checkbox' style)?
    selected = false

    # Fired when a 'button' or 'toolbar' style control is clicked:
    clicked = Event

    #---------------------------------------------------------------------------
    #  Initializes the object:
    #---------------------------------------------------------------------------

    def __init__(self, parent, **traits):
        """ Creates a new image control.

        The control is sized to fit the image and label (plus padding), and
        the image/label positions are computed for the chosen orientation.
        """
        self._image = None

        super(ImageButton, self).__init__(**traits)

        # Calculate the size of the button:
        idx = idy = tdx = tdy = 0
        if self._image is not None:
            idx = self._image.GetWidth()
            idy = self._image.GetHeight()

        if self.label != '':
            dc = wx.ScreenDC()
            dc.SetFont(wx.NORMAL_FONT)
            tdx, tdy = dc.GetTextExtent(self.label)

        wp2 = self.width_padding + 2
        hp2 = self.height_padding + 2
        # Integer division keeps every offset a whole number of pixels.
        if self.orientation == 'horizontal':
            self._ix = wp2
            spacing = (idx > 0) * (tdx > 0) * 4
            self._tx = self._ix + idx + spacing
            dx = idx + tdx + spacing
            dy = max(idy, tdy)
            self._iy = hp2 + ((dy - idy) // 2)
            self._ty = hp2 + ((dy - tdy) // 2)
        else:
            self._iy = hp2
            spacing = (idy > 0) * (tdy > 0) * 2
            self._ty = self._iy + idy + spacing
            dx = max(idx, tdx)
            dy = idy + tdy + spacing
            self._ix = wp2 + ((dx - idx) // 2)
            self._tx = wp2 + ((dx - tdx) // 2)

        # Create the toolkit-specific control:
        self._dx = dx + wp2 + wp2
        self._dy = dy + hp2 + hp2
        self.control = wx.Window(parent, -1, size=wx.Size(self._dx, self._dy))
        self.control._owner = self
        self._mouse_over = self._button_down = False

        # Set up mouse event handlers:
        wx.EVT_ENTER_WINDOW(self.control, self._on_enter_window)
        wx.EVT_LEAVE_WINDOW(self.control, self._on_leave_window)
        wx.EVT_LEFT_DOWN(self.control, self._on_left_down)
        wx.EVT_LEFT_UP(self.control, self._on_left_up)
        wx.EVT_PAINT(self.control, self._on_paint)

    #---------------------------------------------------------------------------
    #  Handles the 'image' trait being changed:
    #---------------------------------------------------------------------------

    def _image_changed(self, image):
        self._image = self._mono_image = None
        if image is not None:
            self._img = image.create_image()
            self._image = self._img.ConvertToBitmap()

        if self.control is not None:
            self.control.Refresh()

    #---------------------------------------------------------------------------
    #  Handles the 'selected' trait being changed:
    #---------------------------------------------------------------------------

    def _selected_changed(self, selected):
        """ Handles the 'selected' trait being changed.

        Selecting a 'radio' style button deselects the first other selected
        ImageButton among the control's siblings.
        """
        # 'selected' may be passed to the constructor, in which case this
        # fires before the wx control exists; nothing to update yet.
        if self.control is None:
            return

        if selected and (self.style == 'radio'):
            for control in self.control.GetParent().GetChildren():
                owner = getattr(control, '_owner', None)
                if (isinstance(owner, ImageButton) and owner.selected and
                        (owner is not self)):
                    owner.selected = False
                    break

        self.control.Refresh()

    #-- wx event handlers ----------------------------------------------------------

    def _on_enter_window(self, event):
        """ Called when the mouse enters the widget.
        """
        if self.style != 'button':
            self._mouse_over = True
            self.control.Refresh()

    def _on_leave_window(self, event):
        """ Called when the mouse leaves the widget.
        """
        if self._mouse_over:
            self._mouse_over = False
            self.control.Refresh()

    def _on_left_down(self, event):
        """ Called when the left mouse button goes down on the widget.
        """
        self._button_down = True
        self.control.CaptureMouse()
        self.control.Refresh()

    def _on_left_up(self, event):
        """ Called when the left mouse button goes up on the widget.

        A release inside the control selects ('radio'), toggles ('checkbox')
        or fires 'clicked' (other styles).
        """
        control = self.control
        # The button may be released over this window without having been
        # pressed on it; only release a capture this control holds.
        if control.HasCapture():
            control.ReleaseMouse()
        self._button_down = False
        wdx, wdy = control.GetClientSizeTuple()
        x, y = event.GetX(), event.GetY()
        control.Refresh()
        if (0 <= x < wdx) and (0 <= y < wdy):
            if self.style == 'radio':
                self.selected = True
            elif self.style == 'checkbox':
                self.selected = not self.selected
            else:
                self.clicked = True

    def _on_paint(self, event):
        """ Called when the widget needs repainting.
        """
        wdc = wx.PaintDC(self.control)
        wdx, wdy = self.control.GetClientSizeTuple()
        ox = (wdx - self._dx) // 2
        oy = (wdy - self._dy) // 2

        disabled = (not self.control.IsEnabled())
        if self._image is not None:
            image = self._image
            if disabled:
                if self._mono_image is None:
                    # Build (once) a gray copy of the image from a weighted
                    # sum of its R, G, B channels.
                    img = self._img
                    data = reshape(fromstring(img.GetData(), dtype('uint8')),
                                   (-1, 3)) * array([[0.297, 0.589, 0.114]])
                    g = data[:, 0] + data[:, 1] + data[:, 2]
                    data[:, 0] = data[:, 1] = data[:, 2] = g
                    img.SetData(ravel(data.astype(dtype('uint8'))).tostring())
                    img.SetMaskColour(0, 0, 0)
                    self._mono_image = img.ConvertToBitmap()
                    self._img = None
                image = self._mono_image
            wdc.DrawBitmap(image, ox + self._ix, oy + self._iy, True)

        if self.label != '':
            if disabled:
                wdc.SetTextForeground(DisabledTextColor)
            wdc.SetFont(wx.NORMAL_FONT)
            wdc.DrawText(self.label, ox + self._tx, oy + self._ty)

        pens = [self._selectedPenLight, self._selectedPenDark]
        bd = self._button_down
        style = self.style
        is_rc = (style in ('radio', 'checkbox'))
        if bd or (style == 'button') or (is_rc and self.selected):
            if is_rc:
                bd = 1 - bd
            wdc.SetBrush(wx.TRANSPARENT_BRUSH)
            wdc.SetPen(pens[bd])
            wdc.DrawLine(1, 1, wdx - 1, 1)
            wdc.DrawLine(1, 1, 1, wdy - 1)
            wdc.DrawLine(2, 2, wdx - 2, 2)
            wdc.DrawLine(2, 2, 2, wdy - 2)
            wdc.SetPen(pens[1 - bd])
            wdc.DrawLine(wdx - 2, 2, wdx - 2, wdy - 1)
            wdc.DrawLine(2, wdy - 2, wdx - 2, wdy - 2)
            wdc.DrawLine(wdx - 3, 3, wdx - 3, wdy - 2)
            wdc.DrawLine(3, wdy - 3, wdx - 3, wdy - 3)

        elif self._mouse_over and (not self.selected):
            wdc.SetBrush(wx.TRANSPARENT_BRUSH)
            wdc.SetPen(pens[bd])
            wdc.DrawLine(0, 0, wdx, 0)
            wdc.DrawLine(0, 1, 0, wdy)
            wdc.SetPen(pens[1 - bd])
            wdc.DrawLine(wdx - 1, 1, wdx - 1, wdy)
            wdc.DrawLine(1, wdy - 1, wdx - 1, wdy - 1)
class SpectrumOptions(AgeOptions):
    """Options controlling the age-spectrum plot: plateau criteria and
    display, error envelope, labels, legend and integrated-age info."""

    label = 'Spectrum'
    step_nsigma = Int(2)
    plot_option_klass = SpectrumPlotOptions

    # Plateau criteria (edited through a small live-modal dialog).
    edit_plateau_criteria = Button
    pc_nsteps = Int(3)
    pc_gas_fraction = Float(50)

    include_j_error_in_plateau = Bool(True)
    plateau_age_error_kind = Enum(*ERROR_TYPES)

    plot_option_name = 'Age'
    display_extract_value = Bool(False)
    display_step = Bool(False)
    display_plateau_info = Bool(True)
    display_integrated_info = Bool(True)
    plateau_sig_figs = Int
    integrated_sig_figs = Int

    plateau_font_size = Enum(6, 7, 8, 10, 11, 12, 14, 15, 18, 24, 28, 32)
    integrated_font_size = Enum(6, 7, 8, 10, 11, 12, 14, 15, 18, 24, 28, 32)
    step_label_font_size = Enum(6, 7, 8, 10, 11, 12, 14, 15, 18, 24, 28, 32)

    envelope_alpha = Range(0, 100, style='simple')
    envelope_color = Color
    user_envelope_color = Bool
    center_line_style = Enum('No Line', 'solid', 'dash', 'dot dash', 'dot',
                             'long dash')
    extend_plateau_end_caps = Bool(True)

    plateau_method = Enum('Fleck 1977', 'Mahon 1996')
    # Alias of ``plateau_age_error_kind`` (see the getter/setter below).
    error_calc_method = Property
    use_error_envelope_fill = Bool

    include_plateau_sample = Bool
    include_plateau_identifier = Bool

    group_editor_klass = SpectrumGroupEditor
    options_klass = SpectrumGroupOptions

    # handlers
    @on_trait_change('display_step,display_extract_value')
    def _handle_labels(self):
        """Show step labels on the last auxiliary plot whenever either the
        step name or the extract value is to be displayed."""
        # Guard: with no auxiliary plots there is nothing to label, and
        # indexing [-1] would raise IndexError.
        if not self.aux_plots:
            return
        labels_enabled = self.display_extract_value or self.display_step
        self.aux_plots[-1].show_labels = labels_enabled

    def _edit_plateau_criteria_fired(self):
        """Open a live-modal dialog for editing the plateau criteria."""
        v = View(Item('pc_nsteps',
                      label='Num. Steps',
                      tooltip='Number of contiguous steps'),
                 Item('pc_gas_fraction',
                      label='Min. Gas%',
                      tooltip='Plateau must represent at least Min. Gas% release'),
                 buttons=['OK', 'Cancel'],
                 title='Edit Plateau Criteria',
                 kind='livemodal')
        self.edit_traits(v)

    def _get_error_calc_method(self):
        return self.plateau_age_error_kind

    def _set_error_calc_method(self, v):
        self.plateau_age_error_kind = v

    def _get_dump_attrs(self):
        """Return the parent's dump attributes plus the spectrum-specific
        ones."""
        attrs = super(SpectrumOptions, self)._get_dump_attrs()
        return attrs + ['step_nsigma',
                        'display_extract_value',
                        'display_step',
                        'display_plateau_info',
                        'display_integrated_info',
                        'plateau_font_size',
                        'integrated_font_size',
                        'step_label_font_size',
                        'envelope_color',
                        'center_line_style',
                        'extend_plateau_end_caps',
                        'include_j_error_in_plateau',
                        'plateau_age_error_kind',
                        'plateau_sig_figs',
                        'integrated_sig_figs',
                        'use_error_envelope_fill',
                        'plateau_method',
                        'pc_nsteps',
                        'pc_gas_fraction',
                        'legend_location',
                        'include_legend',
                        'include_sample_in_legend']

    def _get_groups(self):
        """Build the 'Options' group.

        Returns a one-element tuple containing the group.
        """
        lgrp = VGroup(Item('plateau_method', label='Method'),
                      Item('nsigma'),
                      Item('plateau_age_error_kind',
                           width=-100,
                           label='Error Type'),
                      Item('include_j_error_in_plateau',
                           label='Include J Error'))
        rgrp = VGroup(Item('center_line_style', label='Line Style'),
                      Item('extend_plateau_end_caps', label='Extend End Caps'),
                      icon_button_editor('edit_plateau_criteria', 'cog',
                                         tooltip='Edit Plateau Criteria'))
        plat_grp = HGroup(lgrp, rgrp, show_border=True, label='Plateau')

        grp_grp = VGroup(UItem('group',
                               style='custom',
                               editor=InstanceEditor(view='simple_view')),
                         show_border=True,
                         label='Group Attributes')

        error_grp = VGroup(
            HGroup(Item('step_nsigma',
                        editor=EnumEditor(values=[1, 2, 3]),
                        tooltip='Set the size of the error envelope in standard deviations',
                        label='N. Sigma')),
            show_border=True,
            label='Error Envelope')

        display_grp = Group(
            HGroup(UItem('show_info',
                         tooltip='Show general info in the upper right corner'),
                   show_border=True,
                   label='General'),
            VGroup(Item('include_legend', label='Show'),
                   Item('include_sample_in_legend', label='Include Sample'),
                   Item('legend_location', label='Location'),
                   label='Legend', show_border=True),
            HGroup(Item('display_step', label='Step'),
                   Item('display_extract_value', label='Power/Temp'),
                   spring,
                   Item('step_label_font_size', label='Size'),
                   show_border=True,
                   label='Labels'),
            VGroup(HGroup(UItem('display_plateau_info',
                                tooltip='Display plateau info'),
                          Item('plateau_font_size', label='Size',
                               enabled_when='display_plateau_info'),
                          Item('plateau_sig_figs', label='SigFigs')),
                   HGroup(Item('include_plateau_sample',
                               tooltip='Add the Sample name to the Plateau indicator',
                               label='Sample'),
                          Item('include_plateau_identifier',
                               tooltip='Add the Identifier to the Plateau indicator',
                               label='Identifier')),
                   show_border=True,
                   label='Plateau'),
            HGroup(UItem('display_integrated_info',
                         tooltip='Display integrated age info'),
                   Item('integrated_font_size', label='Size',
                        enabled_when='display_integrated_info'),
                   Item('integrated_sig_figs', label='SigFigs'),
                   show_border=True,
                   label='Integrated'),
            show_border=True,
            label='Display')

        g = Group(self._get_title_group(),
                  grp_grp,
                  plat_grp,
                  error_grp,
                  display_grp,
                  label='Options')
        return g,

    def _load_factory_defaults(self, yd):
        """Load factory defaults from ``yd`` section by section."""
        super(SpectrumOptions, self)._load_factory_defaults(yd)

        self._set_defaults(yd, 'legend', ('legend_location',
                                          'include_legend',
                                          'include_sample_in_legend'))
        # NOTE(review): 'plateau_line_width' and 'plateau_line_color' are not
        # declared as traits on this class; presumably _set_defaults tolerates
        # that -- TODO confirm whether they are still needed.
        self._set_defaults(yd, 'plateau', ('plateau_line_width',
                                           'plateau_line_color',
                                           'plateau_font_size',
                                           'plateau_sig_figs',
                                           'pc_nsteps',
                                           'pc_gas_fraction'))
        self._set_defaults(yd, 'integrated', ('integrated_font_size',
                                              'integrated_sig_figs'))
        self._set_defaults(yd, 'labels', ('display_step',
                                          'display_extract_value',
                                          'step_label_font_size'))
class BarPlot(AbstractPlotRenderer):
    """ A renderer for bar charts.

    Each bar spans from its starting value (taken from ``starting_value``, or
    zero when that source is not set) to its value, centered on its index.
    """

    # The data source to use for the index coordinate.
    index = Instance(ArrayDataSource)

    # The data source to use as value points.
    value = Instance(ArrayDataSource)

    # The data source to use as "starting" values for the bars.
    # For instance, if the values are [10, 20] and starting_value
    # is [3, 7], BarPlot will plot two bars, one between 3 and 10, and
    # one between 7 and 20
    starting_value = Instance(ArrayDataSource)

    # Mapper between index data space and screen space.
    index_mapper = Instance(AbstractMapper)

    # Mapper between value data space and screen space.
    value_mapper = Instance(AbstractMapper)

    # The orientation of the index axis.
    orientation = Enum("h", "v")

    # The direction of the index axis with respect to the graphics context's
    # direction.
    index_direction = Enum("normal", "flipped")

    # The direction of the value axis with respect to the graphics context's
    # direction.
    value_direction = Enum("normal", "flipped")

    # Type of width used for bars:
    #
    # 'data'
    #     The width is in the units along the x-dimension of the data space.
    # 'screen'
    #     The width uses a fixed width of pixels.
    bar_width_type = Enum("data", "screen")

    # Width of the bars, in data or screen space (determined by
    # **bar_width_type**).
    bar_width = Float(10)

    # Round on rectangle dimensions? This is not strictly an "antialias", but
    # it has the same effect through exact pixel drawing.
    antialias = Bool(True)

    # Width of the border of the bars.
    line_width = Float(1.0)

    # Color of the border of the bars.
    line_color = black_color_trait

    # Color to fill the bars.
    fill_color = black_color_trait

    # The RGBA tuple for rendering lines. It is always a tuple of length 4.
    # It has the same RGB values as line_color_, and its alpha value is the
    # alpha value of self.line_color multiplied by self.alpha.
    effective_line_color = Property(Tuple, depends_on=['line_color', 'alpha'])

    # The RGBA tuple for rendering the fill. It is always a tuple of length 4.
    # It has the same RGB values as fill_color_, and its alpha value is the
    # alpha value of self.fill_color multiplied by self.alpha.
    effective_fill_color = Property(Tuple, depends_on=['fill_color', 'alpha'])

    # Overall alpha value of the image. Ranges from 0.0 for transparent to 1.0
    alpha = Range(0.0, 1.0, 1.0)

    # Convenience properties that correspond to either index_mapper or
    # value_mapper, depending on the orientation of the plot.

    # Corresponds to either **index_mapper** or **value_mapper**, depending on
    # the orientation of the plot.
    x_mapper = Property

    # Corresponds to either **value_mapper** or **index_mapper**, depending on
    # the orientation of the plot.
    y_mapper = Property

    # Corresponds to either **index_direction** or **value_direction**,
    # depending on the orientation of the plot.
    x_direction = Property

    # Corresponds to either **value_direction** or **index_direction**,
    # depending on the orientation of the plot
    y_direction = Property

    # Convenience property for accessing the index data range.
    index_range = Property

    # Convenience property for accessing the value data range.
    value_range = Property

    #------------------------------------------------------------------------
    # Private traits
    #------------------------------------------------------------------------

    # Indicates whether or not the data cache is valid
    _cache_valid = Bool(False)

    # Cached data values from the datasources.  If **bar_width_type** is "data",
    # then this is an Nx4 array of (bar_left, bar_right, start, end) for a
    # bar plot in normal orientation.  If **bar_width_type** is "screen", then
    # this is an Nx3 array of (bar_center, start, end).
    _cached_data_pts = Any

    #------------------------------------------------------------------------
    # AbstractPlotRenderer interface
    #------------------------------------------------------------------------

    def __init__(self, *args, **kw):
        # These Traits depend on others, so we'll defer setting them until
        # after the HasTraits initialization has been completed.
        later_list = ['index_direction', 'value_direction']
        postponed = {}
        for name in later_list:
            if name in kw:
                postponed[name] = kw.pop(name)

        super(BarPlot, self).__init__(*args, **kw)

        # Set any keyword Traits that were postponed.  ``trait_set`` is the
        # supported spelling; ``HasTraits.set`` is deprecated.
        self.trait_set(**postponed)

    def map_screen(self, data_array):
        """ Maps an Nx2 array of (index, value) data points into screen space
        and returns it as an array (or ``[]`` for empty input).

        Implements the AbstractPlotRenderer interface.
        """
        if len(data_array) == 0:
            return []

        x_ary, y_ary = transpose(data_array)

        sx = self.index_mapper.map_screen(x_ary)
        sy = self.value_mapper.map_screen(y_ary)

        if self.orientation == "h":
            return transpose(array((sx, sy)))
        else:
            return transpose(array((sy, sx)))

    def map_data(self, screen_pt):
        """ Maps a screen space point into the "index" space of the plot.

        Implements the AbstractPlotRenderer interface.
        """
        if self.orientation == "h":
            screen_coord = screen_pt[0]
        else:
            screen_coord = screen_pt[1]
        return self.index_mapper.map_data(screen_coord)

    def map_index(self, screen_pt, threshold=2.0, outside_returns_none=True,
                  index_only=False):
        """ Maps a screen space point to an index into the plot's index
        array(s), or returns None when no data point is within
        ``threshold`` pixels.

        With ``index_only``, a point also matches when it is within
        ``threshold`` pixels along the screen axis that carries the index,
        regardless of its distance along the value axis.

        Implements the AbstractPlotRenderer interface.
        """
        data_pt = self.map_data(screen_pt)
        index_range = self.index_mapper.range
        if ((data_pt < index_range.low) or
                (data_pt > index_range.high)) and outside_returns_none:
            return None

        index_data = self.index.get_data()
        value_data = self.value.get_data()

        if len(value_data) == 0 or len(index_data) == 0:
            return None

        try:
            ndx = reverse_map_1d(index_data, data_pt, self.index.sort_order)
        except IndexError:
            return None

        x = index_data[ndx]
        y = value_data[ndx]

        result = self.map_screen(array([[x, y]]))
        if result is None:
            return None

        sx, sy = result[0]
        if index_only:
            # The index lives on screen x for "h" and on screen y for "v"
            # (see map_data); compare the absolute distance on that axis.
            if self.orientation == "h":
                index_dist = abs(screen_pt[0] - sx)
            else:
                index_dist = abs(screen_pt[1] - sy)
            if index_dist < threshold:
                return ndx

        if ((screen_pt[0] - sx) ** 2 + (screen_pt[1] - sy) ** 2
                < threshold * threshold):
            return ndx
        return None

    #------------------------------------------------------------------------
    # PlotComponent interface
    #------------------------------------------------------------------------

    def _gather_points(self):
        """ Collects data points that are within the range of the plot, and
        caches them in **_cached_data_pts**.

        When either data source is missing, an empty array is cached so
        that _draw_plot can bail out cleanly.
        """
        if self.index is None or self.value is None:
            self._cached_data_pts = array([])
            return

        index, index_mask = self.index.get_data_mask()
        value, value_mask = self.value.get_data_mask()

        if len(index) == 0 or len(value) == 0 or len(index) != len(value):
            logger.warning("Chaco: using empty dataset; index_len=%d, "
                           "value_len=%d.", len(index), len(value))
            self._cached_data_pts = array([])
            self._cache_valid = True
            return

        # TODO: Until we code up a better handling of value-based culling that
        # takes into account starting_value and dataspace bar widths, just use
        # the index culling for now.
        index_range_mask = self.index_mapper.range.mask_data(index)
        # Exclude NaN positions in the index data itself (the boolean mask
        # can never be NaN).
        nan_mask = invert(isnan(index))
        point_mask = index_mask & nan_mask & index_range_mask

        if self.starting_value is None:
            starting_values = zeros(len(index))
        else:
            starting_values = self.starting_value.get_data()

        if self.bar_width_type == "data":
            half_width = self.bar_width / 2.0
            points = column_stack((index - half_width, index + half_width,
                                   starting_values, value))
        else:
            points = column_stack((index, starting_values, value))
        self._cached_data_pts = compress(point_mask, points, axis=0)

        self._cache_valid = True

    def _draw_plot(self, gc, view_bounds=None, mode="normal"):
        """ Draws the 'plot' layer.
        """
        if not self._cache_valid:
            self._gather_points()

        data = self._cached_data_pts
        if data is None or data.size == 0:
            # Nothing to draw.
            return

        with gc:
            gc.clip_to_rect(self.x, self.y, self.width, self.height)
            gc.set_antialias(self.antialias)
            gc.set_stroke_color(self.effective_line_color)
            gc.set_fill_color(self.effective_fill_color)
            gc.set_line_width(self.line_width)

            if self.bar_width_type == "data":
                # map the bar start and stop locations into screen space
                lower_left_pts = self.map_screen(data[:, (0, 2)])
                upper_right_pts = self.map_screen(data[:, (1, 3)])
            else:
                half_width = self.bar_width / 2.0
                # map the bar centers into screen space and then compute the
                # bar start and end positions along the screen axis that
                # carries the index (x for "h", y for "v"; see map_screen).
                lower_left_pts = self.map_screen(data[:, (0, 1)])
                upper_right_pts = self.map_screen(data[:, (0, 2)])
                width_axis = 0 if self.orientation == "h" else 1
                lower_left_pts[:, width_axis] -= half_width
                upper_right_pts[:, width_axis] += half_width

            bounds = upper_right_pts - lower_left_pts
            gc.rects(column_stack((lower_left_pts, bounds)))
            gc.draw_path()

    def _draw_default_axes(self, gc):
        """ Draws the origin axis lines for any range that straddles zero. """
        if not self.origin_axis_visible:
            return

        with gc:
            gc.set_stroke_color(self.origin_axis_color_)
            gc.set_line_width(self.origin_axis_width)
            gc.set_line_dash(None)

            for data_range in (self.index_mapper.range,
                               self.value_mapper.range):
                if (data_range.low < 0) and (data_range.high > 0):
                    if data_range == self.index_mapper.range:
                        dual = self.value_mapper.range
                        data_pts = array([[0.0, dual.low], [0.0, dual.high]])
                    else:
                        dual = self.index_mapper.range
                        data_pts = array([[dual.low, 0.0], [dual.high, 0.0]])
                    start, end = self.map_screen(data_pts)
                    gc.move_to(int(start[0]) + 0.5, int(start[1]) + 0.5)
                    gc.line_to(int(end[0]) + 0.5, int(end[1]) + 0.5)
                    gc.stroke_path()

    def _render_icon(self, gc, x, y, width, height):
        """ Draws a small filled, stroked square as the legend icon. """
        with gc:
            gc.set_fill_color(self.effective_fill_color)
            gc.set_stroke_color(self.effective_line_color)
            gc.rect(x + width / 4, y + height / 4, width / 2, height / 2)
            gc.draw_path(FILL_STROKE)

    def _post_load(self):
        super(BarPlot, self)._post_load()

    #------------------------------------------------------------------------
    # Properties
    #------------------------------------------------------------------------

    def _get_index_range(self):
        return self.index_mapper.range

    def _set_index_range(self, val):
        self.index_mapper.range = val

    def _get_value_range(self):
        return self.value_mapper.range

    def _set_value_range(self, val):
        self.value_mapper.range = val

    def _get_x_mapper(self):
        if self.orientation == "h":
            return self.index_mapper
        else:
            return self.value_mapper

    def _get_y_mapper(self):
        if self.orientation == "h":
            return self.value_mapper
        else:
            return self.index_mapper

    def _get_x_direction(self):
        if self.orientation == "h":
            return self.index_direction
        else:
            return self.value_direction

    def _get_y_direction(self):
        if self.orientation == "h":
            return self.value_direction
        else:
            return self.index_direction

    #------------------------------------------------------------------------
    # Event handlers - these are mostly copied from BaseXYPlot
    #------------------------------------------------------------------------

    def _update_mappers(self):
        """ Updates the index and value mappers. Called by trait change
        handlers for various traits.
        """
        x_mapper = self.index_mapper
        y_mapper = self.value_mapper
        x_dir = self.index_direction
        y_dir = self.value_direction

        if self.orientation == "v":
            x_mapper, y_mapper = y_mapper, x_mapper
            x_dir, y_dir = y_dir, x_dir

        x = self.x
        x2 = self.x2
        y = self.y
        y2 = self.y2

        if x_mapper is not None:
            if x_dir == "normal":
                x_mapper.low_pos = x
                x_mapper.high_pos = x2
            else:
                x_mapper.low_pos = x2
                x_mapper.high_pos = x

        if y_mapper is not None:
            if y_dir == "normal":
                y_mapper.low_pos = y
                y_mapper.high_pos = y2
            else:
                y_mapper.low_pos = y2
                y_mapper.high_pos = y

        self.invalidate_draw()
        self._cache_valid = False

    @on_trait_change('line_color, line_width, fill_color, alpha')
    def _attributes_changed(self):
        self.invalidate_draw()
        self.request_redraw()

    def _bounds_changed(self, old, new):
        super(BarPlot, self)._bounds_changed(old, new)
        self._update_mappers()

    def _bounds_items_changed(self, event):
        super(BarPlot, self)._bounds_items_changed(event)
        self._update_mappers()

    def _orientation_changed(self):
        self._update_mappers()

    def _index_changed(self, old, new):
        if old is not None:
            old.on_trait_change(self._either_data_changed, "data_changed",
                                remove=True)
        if new is not None:
            new.on_trait_change(self._either_data_changed, "data_changed")
        self._either_data_changed()

    def _index_direction_changed(self):
        # Flipping the direction swaps the mapper's screen endpoints.
        m = self.index_mapper
        if m is not None:
            m.low_pos, m.high_pos = m.high_pos, m.low_pos
        self.invalidate_draw()

    def _value_direction_changed(self):
        m = self.value_mapper
        if m is not None:
            m.low_pos, m.high_pos = m.high_pos, m.low_pos
        self.invalidate_draw()

    def _either_data_changed(self):
        self.invalidate_draw()
        self._cache_valid = False
        self.request_redraw()

    def _value_changed(self, old, new):
        if old is not None:
            old.on_trait_change(self._either_data_changed, "data_changed",
                                remove=True)
        if new is not None:
            new.on_trait_change(self._either_data_changed, "data_changed")
        self._either_data_changed()

    def _index_mapper_changed(self, old, new):
        return self._either_mapper_changed(old, new)

    def _value_mapper_changed(self, old, new):
        return self._either_mapper_changed(old, new)

    def _either_mapper_changed(self, old, new):
        if old is not None:
            old.on_trait_change(self._mapper_updated_handler, "updated",
                                remove=True)
        if new is not None:
            new.on_trait_change(self._mapper_updated_handler, "updated")
        self.invalidate_draw()

    def _mapper_updated_handler(self):
        self._cache_valid = False
        self.invalidate_draw()
        self.request_redraw()

    def _bar_width_changed(self):
        self._cache_valid = False
        self.invalidate_draw()
        self.request_redraw()

    def _bar_width_type_changed(self):
        self._cache_valid = False
        self.invalidate_draw()
        self.request_redraw()

    #------------------------------------------------------------------------
    # Property getters
    #------------------------------------------------------------------------

    def _apply_alpha(self, rgba):
        """ Return the RGB of ``rgba`` with its own alpha (1.0 if it has
        none) multiplied by the plot's overall ``alpha``. """
        base_alpha = rgba[-1] if len(rgba) == 4 else 1.0
        return tuple(rgba[:3]) + (base_alpha * self.alpha, )

    @cached_property
    def _get_effective_line_color(self):
        return self._apply_alpha(self.line_color_)

    @cached_property
    def _get_effective_fill_color(self):
        return self._apply_alpha(self.fill_color_)
class AnnotationEditor(HasTraits):
    """ Editor that mirrors an annotation ``component``'s border, text and
    color settings and pushes edits made here back to the component. """

    component = Any
    border_visible = Bool(True)
    border_width = Range(0, 10)
    border_color = Color
    font = Font('modern 12')
    text_color = Color
    bgcolor = Color
    text = Str
    bg_visible = Bool(True)

    @on_trait_change('component:text')
    def _component_text_changed(self):
        if self.component:
            self.text = self.component.text

    def _component_changed(self):
        """ Copy the new component's settings into this editor. """
        if self.component:
            traits = ('border_visible', 'border_width', 'text')
            d = self.component.trait_get(traits)
            # trait_set's only positional parameter is trait_change_notify,
            # so the traits must be passed purely as keywords.
            self.trait_set(**d)
            for c in ('border_color', 'text_color', 'bgcolor'):
                v = getattr(self.component, c)
                if not isinstance(v, str):
                    # Non-string component colors are scaled from [0, 1]
                    # floats to 0-255 ints; any alpha channel is dropped.
                    v = tuple(int(round(channel * 255)) for channel in v[:3])
                self.trait_set(**{c: v})

    def _bg_visible_changed(self):
        if self.component:
            if self.bg_visible:
                self.component.bgcolor = self.bgcolor
            else:
                self.component.bgcolor = transparent_color
            self.component.request_redraw()

    @on_trait_change('border_+, text_color, bgcolor, text')
    def _update(self, name, new):
        """ Push a single edited trait value through to the component. """
        if self.component:
            # While the background is hidden, keep it transparent; the chosen
            # color is applied when bg_visible is switched back on.
            if name == 'bgcolor' and not self.bg_visible:
                return
            self.component.trait_set(**{name: new})
            self.component.request_redraw()

    @on_trait_change('font')
    def _update_font(self):
        if self.component:
            self.component.font = str(self.font)
            self.component.request_redraw()

    def traits_view(self):
        v = View(
            VGroup(Item('font', width=75),
                   Item('text_color', label='Text'),
                   HGroup(UItem('bg_visible',
                                tooltip='Is the background visible'),
                          Item('bgcolor', label='Background',
                               enabled_when='bg_visible')),
                   UItem('text', style='custom'),
                   Group(Item('border_visible'),
                         Item('border_width', enabled_when='border_visible'),
                         Item('border_color', enabled_when='border_visible'),
                         label='Border'),
                   visible_when='component'))
        return v
class Elementary1DRule(NDimRule):
    """ An elementary one-dimensional cellular automaton rule.

    The rule is identified by its Wolfram rule number, and boundary
    conditions are delegated to ``scipy.ndimage``.

    Notes
    -----
    See `Wikipedia <https://en.wikipedia.org/wiki/Elementary_cellular_automaton>`_
    for further information on how these automata work.
    """

    # Elementary1DRule Traits ------------------------------------------------

    #: The Wolfram number of the rule (0-255).
    rule_number = Range(0, 255)

    #: The state value for "empty" cells.
    empty_state = StateValue(0)

    #: The state value for "filled" cells.
    filled_state = StateValue(1)

    #: Eight booleans; entry i is bit i (least significant first) of
    #: ``rule_number``.
    bit_mask = Property(Array(shape=(8,), dtype=bool),
                        depends_on='rule_number')

    #: The boundary mode to use.
    boundary = Enum('empty', 'filled', 'nearest', 'wrap', 'reflect')

    # NDimRule Traits --------------------------------------------------------

    #: These are 1-dimensional only rules.
    ndim = Constant(1)

    # ------------------------------------------------------------------------
    # Elementary1DRule interface
    # ------------------------------------------------------------------------

    def reflect(self):
        """ Mirror the automaton left-to-right by permuting its bit mask. """
        current = self.bit_mask
        self.bit_mask = current[REVERSE_PERMUTATION]

    def complement(self):
        """ Swap the roles of filled and empty cells: reverse the bit mask
        and invert every bit. """
        reversed_bits = self.bit_mask[::-1]
        self.bit_mask = ~reversed_bits

    # ------------------------------------------------------------------------
    # AbstractRule interface
    # ------------------------------------------------------------------------

    def step(self, states):
        """ Advance the automaton by one generation.

        Parameters
        ----------
        states : array
            An array holding the current states of the automata.

        Returns
        -------
        states : array
            A new uint8 array holding ``filled_state`` where the rule fires
            and ``empty_state`` elsewhere.
        """
        states = super(NDimRule, self).step(states)

        # "empty"/"filled" boundaries pad with a constant 0/1 filled-flag;
        # every other mode is passed straight through to ndimage.
        mode = self.boundary
        if mode == 'empty':
            filter_kwargs = {'mode': 'constant', 'cval': 0}
        elif mode == 'filled':
            filter_kwargs = {'mode': 'constant', 'cval': 1}
        else:
            filter_kwargs = {'mode': mode}

        is_filled = ndimage.generic_filter1d(
            states == self.filled_state,
            self._rule_filter,
            filter_size=3,
            **filter_kwargs)

        next_states = np.full(is_filled.shape, self.empty_state,
                              dtype='uint8')
        next_states[is_filled] = self.filled_state
        return next_states

    # ------------------------------------------------------------------------
    # Private interface
    # ------------------------------------------------------------------------

    def _rule_filter(self, iline, oline):
        """ Kernel for ``generic_filter1d``: encode each (left, centre,
        right) neighbourhood as a 3-bit number and look up the rule bit. """
        left = iline[:-2]
        centre = iline[1:-1]
        right = iline[2:]
        neighbourhood = (left * 4 + centre * 2 + right).astype('uint8')
        oline[...] = self.bit_mask[neighbourhood]

    # Trait properties -------------------------------------------------------

    @cached_property
    def _get_bit_mask(self):
        packed = np.array([self.rule_number], dtype='uint8')
        msb_first = np.unpackbits(packed)
        return msb_first[::-1].astype(bool)

    def _set_bit_mask(self, bits):
        lsb_first = np.asarray(bits, dtype=bool)
        packed = np.packbits(lsb_first[::-1])
        self.rule_number = int(packed[0])
class AnOddClass(HasTraits):
    """ Example class whose traits are validated by ``TraitOddInteger``. """

    # Validated by TraitOddInteger; default value 1.
    oddball = Trait(1, TraitOddInteger())

    # Compound trait: accepted by TraitOddInteger or by Range(-10, -1);
    # default value -1.
    very_odd = Trait(-1, TraitOddInteger(), Range(-10, -1))
class Worker(HasTraits):
    """Create a synthetic 3D data set, visualize it in Mayavi and modify it.

    A deliberately simple example of driving Mayavi from a workbench
    service: buttons create, view and reset the data, and the ``scale``
    slider shifts the scalar values in place.
    """

    # Set by envisage when this is contributed as a ServiceOffer.
    window = Instance('pyface.workbench.api.WorkbenchWindow')

    create_data = Button('Create data')
    reset_data = Button('Reset data')
    view_data = Button('View data')
    scale = Range(0.0, 1.0)
    source = Instance('mayavi.core.source.Source')

    # Our UI view.
    view = View(Item('create_data', show_label=False),
                Item('view_data', show_label=False),
                Item('reset_data', show_label=False),
                Item('scale'),
                resizable=True)

    def get_mayavi(self):
        """Return the Mayavi ``Script`` service from the workbench window."""
        from mayavi.plugins.script import Script
        return self.window.get_service(Script)

    def _make_data(self):
        """Return a contiguous float32 64x64x64 array of sin(xyz)/(xyz)."""
        dims = [64, 64, 64]
        # An even number of samples over [-5, 5] never lands exactly on 0,
        # so the division below is safe.
        x, y, z = numpy.ogrid[-5:5:dims[0] * 1j,
                              -5:5:dims[1] * 1j,
                              -5:5:dims[2] * 1j]
        x = x.astype('f')
        y = y.astype('f')
        z = z.astype('f')

        s = (numpy.sin(x * y * z) / (x * y * z))
        s = s.transpose().copy()  # This makes the data contiguous.
        return s

    def _create_data_fired(self):
        mayavi = self.get_mayavi()
        from mayavi.sources.array_source import ArraySource
        s = self._make_data()
        src = ArraySource(transpose_input_array=False, scalar_data=s)
        self.source = src
        mayavi.add_source(src)

    def _reset_data_fired(self):
        # Nothing to reset until data has been created.
        if self.source is None:
            return
        self.source.scalar_data = self._make_data()

    def _view_data_fired(self):
        mayavi = self.get_mayavi()
        from mayavi.modules.outline import Outline
        from mayavi.modules.image_plane_widget import ImagePlaneWidget
        # Visualize the data.
        o = Outline()
        mayavi.add_module(o)
        ipw = ImagePlaneWidget()
        mayavi.add_module(ipw)
        ipw.module_manager.scalar_lut_manager.show_scalar_bar = True

        ipw_y = ImagePlaneWidget()
        mayavi.add_module(ipw_y)
        ipw_y.ipw.plane_orientation = 'y_axes'

    def _scale_changed(self, value):
        """Shift the scalar data by ``value * 0.01`` in place, wrapped into
        [0, 1), and refresh the source."""
        src = self.source
        # The slider can be moved before any data exists.
        if src is None:
            return
        data = src.scalar_data
        data += value * 0.01
        numpy.mod(data, 1.0, data)
        src.update()
class Reactor(HasTraits):
    """ Minimal model with a bounded core temperature and a water level. """

    # Float bounded to [-273.0, 100000.0] (presumably degrees Celsius --
    # TODO confirm units).
    core_temperature = Range(-273.0, 100000.0)

    # Unconstrained float.
    water_level = Float
class SelectOutput(Filter):
    """Pass exactly one of the input's output blocks downstream.

    The user picks ``output_index``; for a multi-block data set it selects
    a block, otherwise it selects one of the input's outputs. This is
    handy for multi-block data sources.
    """

    # The output index in the input to choose from.
    output_index = Range(value=0,
                         enter_set=True,
                         auto_set=False,
                         low='_min_index',
                         high='_max_index')

    input_info = PipelineInfo(datasets=['any'],
                              attribute_types=['any'],
                              attributes=['any'])

    output_info = PipelineInfo(datasets=['any'],
                               attribute_types=['any'],
                               attributes=['any'])

    # The minimum output index of our input.
    _min_index = Int(0, desc='the minimum output index')

    # The maximum output index of our input.
    _max_index = Int(0, desc='the maximum output index')

    _my_input = Any(transient=True)

    # Traits View: the index editor is only enabled with more than one choice.
    view = View(Group(Item('output_index',
                           enabled_when='_max_index > 0')),
                resizable=True)

    # -- `object` interface ------------------------------------------------

    def __get_pure_state__(self):
        state = super(SelectOutput, self).__get_pure_state__()
        state['output_index'] = self.output_index
        return state

    def __set_pure_state__(self, state):
        super(SelectOutput, self).__set_pure_state__(state)
        # Invoke the handler explicitly so the output is rebuilt even when
        # restoring does not itself fire a change of output_index.
        self._output_index_changed(state.output_index)

    # -- `Filter` interface ------------------------------------------------

    def update_pipeline(self):
        """Recompute the valid index range from the input and re-select."""
        inputs = self.inputs
        if not inputs:
            return

        obj = get_new_output(inputs[0].outputs[0])
        self._my_input = obj
        if hasattr(obj, 'number_of_blocks'):
            choice_count = obj.number_of_blocks
        else:
            choice_count = len(inputs[0].outputs)
        self._max_index = choice_count - 1
        self._output_index_changed(self.output_index)

    def update_data(self):
        # Propagate the event.
        self.data_changed = True

    # -- Trait handlers ----------------------------------------------------

    def _setup_output(self, value):
        """Publish the block/output at ``value`` through a TrivialProducer."""
        source = self._my_input
        producer = tvtk.TrivialProducer()
        if hasattr(source, 'number_of_blocks'):
            chosen = source.get_block(value)
        else:
            chosen = self.inputs[0].outputs[value]
        producer.set_output(chosen)
        self._set_outputs([producer])

    def _output_index_changed(self, value):
        """Static trait handler: clamp out-of-range values (the reassignment
        re-enters this handler), otherwise rebuild the output and re-render."""
        if value > self._max_index:
            self.output_index = self._max_index
            return
        if value < self._min_index:
            self.output_index = self._min_index
            return

        self._setup_output(value)
        scene = self.scene
        if scene is None:
            return
        scene.renderer.reset_camera_clipping_range()
        scene.render()
class VUMeter(Component):
    """A VU-meter style gauge: a curved dB/percent axis plus a needle.

    The needle angle is derived from `percent` (or equivalently `db`) via
    `_percent_to_theta`; values above `max_percent` pin the needle at
    `max_percent`.
    """

    # Value expressed in dB
    db = Property(Float)

    # Value expressed as a percent.
    percent = Range(low=0.0)

    # The maximum value to be displayed in the VU Meter, expressed as a
    # percent.
    max_percent = Float(150.0)

    # Angle (in degrees) from a horizontal line through the hinge of the
    # needle to the edge of the meter axis.
    angle = Float(45.0)

    # Values of the percentage-based ticks; these are drawn and labeled along
    # the bottom of the curve axis.
    percent_ticks = List(list(range(0, 101, 20)))

    # Text to write in the middle of the VU Meter.
    text = Str("VU")

    # Font used to draw `text`.
    text_font = KivaFont("modern 48")

    # Font for the db tick labels.
    db_tick_font = KivaFont("modern 16")

    # Font for the percent tick labels.
    percent_tick_font = KivaFont("modern 12")

    # beta is the fraction of the needle that is "hidden".
    # beta == 0 puts the hinge point of the needle on the bottom
    # edge of the window.  Values that result in a decent looking
    # meter are 0 < beta < .65.
    # XXX needs a better name!
    _beta = Float(0.3)

    # _outer_radial_margin is the radial extent beyond the circular axis
    # to include in calculations of the space required for the meter.
    # This allows room for the ticks and labels.
    _outer_radial_margin = Float(60.0)

    # The angle (in radians) of the span of the curve axis.
    _phi = Property(Float, observe=["angle"])

    # This is the radius of the circular axis (in screen coordinates).
    _axis_radius = Property(Float, observe=["_phi", "width", "height"])

    # ---------------------------------------------------------------------
    # Trait Property methods
    # ---------------------------------------------------------------------

    def _get_db(self):
        db = percent_to_db(self.percent)
        return db

    def _set_db(self, value):
        self.percent = db_to_percent(value)

    def _get__phi(self):
        phi = math.pi * (180.0 - 2 * self.angle) / 180.0
        return phi

    def _get__axis_radius(self):
        M = self._outer_radial_margin
        beta = self._beta
        w = self.width
        h = self.height
        phi = self._phi

        # R1 is the largest radius that fits horizontally, R2 the largest
        # that fits vertically; the smaller one wins.
        R1 = w / (2 * math.sin(phi / 2)) - M
        R2 = (h - M) / (1 - beta * math.cos(phi / 2))
        R = min(R1, R2)
        return R

    # ---------------------------------------------------------------------
    # Trait change handlers
    # ---------------------------------------------------------------------

    def _anytrait_changed(self):
        # Any trait change may affect the appearance, so always redraw.
        self.request_redraw()

    # ---------------------------------------------------------------------
    # Component API
    # ---------------------------------------------------------------------

    def _draw_mainlayer(self, gc, view_bounds=None, mode="default"):
        """Draw the ticks, labels, curved axes, hinge arc and needle."""
        beta = self._beta
        phi = self._phi

        M = self._outer_radial_margin
        R = self._axis_radius

        # (ox, oy) is the position of the "hinge point" of the needle
        # (i.e. the center of rotation).  For beta > ~0, oy is negative,
        # so this point is below the visible region.
        ox = self.x + self.width // 2
        oy = -beta * R * math.cos(phi / 2) + 1

        left_theta = math.radians(180 - self.angle)
        right_theta = math.radians(self.angle)

        # The angle of the 100% position.
        nominal_theta = self._percent_to_theta(100.0)

        # The color of the axis for percent > 100.
        red = (0.8, 0, 0)

        with gc:
            gc.set_antialias(True)

            # Draw everything relative to the center of the circles.
            gc.translate_ctm(ox, oy)

            # Draw the primary ticks and tick labels on the curved axis.
            gc.set_fill_color((0, 0, 0))
            gc.set_font(self.db_tick_font)
            for db in [-20, -10, -7, -5, -3, -2, -1, 0, 1, 2, 3]:
                db_percent = db_to_percent(db)
                theta = self._percent_to_theta(db_percent)
                x1 = R * math.cos(theta)
                y1 = R * math.sin(theta)
                x2 = (R + 0.3 * M) * math.cos(theta)
                y2 = (R + 0.3 * M) * math.sin(theta)
                gc.set_line_width(2.5)
                gc.move_to(x1, y1)
                gc.line_to(x2, y2)
                gc.stroke_path()

                text = str(db)
                if db > 0:
                    text = "+" + text
                self._draw_rotated_label(gc, text, theta, R + 0.4 * M)

            # Draw the secondary ticks on the curve axis.
            for db in [-15, -9, -8, -6, -4, -0.5, 0.5]:
                db_percent = db_to_percent(db)
                theta = self._percent_to_theta(db_percent)
                x1 = R * math.cos(theta)
                y1 = R * math.sin(theta)
                x2 = (R + 0.2 * M) * math.cos(theta)
                y2 = (R + 0.2 * M) * math.sin(theta)
                gc.set_line_width(1.0)
                gc.move_to(x1, y1)
                gc.line_to(x2, y2)
                gc.stroke_path()

            # Draw the percent ticks and label on the bottom of the
            # curved axis.
            gc.set_font(self.percent_tick_font)
            gc.set_fill_color((0.5, 0.5, 0.5))
            gc.set_stroke_color((0.5, 0.5, 0.5))
            percents = self.percent_ticks
            for tick_percent in percents:
                theta = self._percent_to_theta(tick_percent)
                x1 = (R - 0.15 * M) * math.cos(theta)
                y1 = (R - 0.15 * M) * math.sin(theta)
                x2 = R * math.cos(theta)
                y2 = R * math.sin(theta)
                gc.set_line_width(2.0)
                gc.move_to(x1, y1)
                gc.line_to(x2, y2)
                gc.stroke_path()

                text = str(tick_percent)
                # Only the last tick label carries the unit.
                if tick_percent == percents[-1]:
                    text = text + "%"
                self._draw_rotated_label(gc, text, theta, R - 0.3 * M)

            if self.text:
                gc.set_font(self.text_font)
                tx, ty, tw, th = gc.get_text_extent(self.text)
                gc.set_fill_color((0, 0, 0, 0.25))
                gc.set_text_matrix(affine.affine_from_rotation(0))
                gc.set_text_position(-0.5 * tw, (0.75 * beta + 0.25) * R)
                gc.show_text(self.text)

            # Draw the red curved axis (from the right edge up to 100%).
            gc.set_stroke_color(red)
            red_lw = 10
            gc.set_line_width(red_lw)
            gc.arc(0, 0, R + 0.5 * red_lw - 1, right_theta, nominal_theta)
            gc.stroke_path()

            # Draw the black curved axis (from 100% to the left edge).
            black_lw = 4
            gc.set_line_width(black_lw)
            gc.set_stroke_color((0, 0, 0))
            gc.arc(0, 0, R + 0.5 * black_lw - 1, nominal_theta, left_theta)
            gc.stroke_path()

            # Draw the filled arc at the bottom: outline first, then a
            # translucent fill over the same arc.
            gc.set_line_width(2)
            gc.set_stroke_color((0, 0, 0))
            gc.arc(0, 0, beta * R, right_theta, left_theta)
            gc.stroke_path()
            gc.set_fill_color((0, 0, 0, 0.25))
            gc.arc(0, 0, beta * R, right_theta, left_theta)
            gc.fill_path()

            # Draw the needle.  If percent exceeds max_percent, the needle
            # is drawn at max_percent.
            percent = min(self.percent, self.max_percent)
            needle_theta = self._percent_to_theta(percent)
            gc.rotate_ctm(needle_theta - 0.5 * math.pi)
            self._draw_vertical_needle(gc)

    # ---------------------------------------------------------------------
    # Private methods
    # ---------------------------------------------------------------------

    def _draw_vertical_needle(self, gc):
        """ Draw the needle of the meter, pointing straight up. """
        beta = self._beta
        R = self._axis_radius
        end_y = beta * R
        blob_y = R - 0.6 * self._outer_radial_margin
        tip_y = R + 0.2 * self._outer_radial_margin
        lw = 5

        with gc:
            gc.set_alpha(1)
            gc.set_fill_color((0, 0, 0))

            # Draw the needle from the bottom to the blob.
            gc.set_line_width(lw)
            gc.move_to(0, end_y)
            gc.line_to(0, blob_y)
            gc.stroke_path()

            # Draw the thin part of the needle from the blob to the tip.
            gc.move_to(lw, blob_y)
            control_y = blob_y + 0.25 * (tip_y - blob_y)
            gc.quad_curve_to(0.2 * lw, control_y, 0, tip_y)
            gc.quad_curve_to(-0.2 * lw, control_y, -lw, blob_y)
            gc.line_to(lw, blob_y)
            gc.fill_path()

            # Draw the blob on the needle.
            gc.arc(0, blob_y, 6.0, 0, 2 * math.pi)
            gc.fill_path()

    def _draw_rotated_label(self, gc, text, theta, radius):
        """Draw `text` at `radius`, rotated to follow the ray at `theta`.

        The start point is shifted by half the text width (angle `dtheta`),
        which appears intended to center the label on the ray.
        """
        tx, ty, tw, th = gc.get_text_extent(text)

        rr = math.sqrt(radius ** 2 + (0.5 * tw) ** 2)
        dtheta = math.atan2(0.5 * tw, radius)
        text_theta = theta + dtheta
        x = rr * math.cos(text_theta)
        y = rr * math.sin(text_theta)

        rot_theta = theta - 0.5 * math.pi
        with gc:
            gc.set_text_matrix(affine.affine_from_rotation(rot_theta))
            gc.set_text_position(x, y)
            gc.show_text(text)

    def _percent_to_theta(self, percent):
        """ Convert percent to the angle theta, in radians.

        theta is the angle of the needle measured counterclockwise from
        the horizontal (i.e. the traditional angle of polar coordinates).
        """
        angle = (
            self.angle
            + (180.0 - 2 * self.angle)
            * (self.max_percent - percent)
            / self.max_percent
        )
        theta = math.radians(angle)
        return theta

    def _db_to_theta(self, db):
        """ Convert db to the angle theta, in radians. """
        percent = db_to_percent(db)
        theta = self._percent_to_theta(percent)
        return theta
class Test(HasTraits): var = Range(low=[]) # E: arg-type var2 = Range(low="3") var3 = Range(low=3)
class RemapDemo(ImageProcessDemo):
    """Interactive image warping with ``cv2.remap``.

    Left-drag moves the image content found at the press point to the
    release point, within a disc of size ``radius``, smoothed by ``sigma``.
    Releasing commits the warp (keeping up to 10 undo snapshots); a
    right-click restores the previous snapshot.  Scrolling changes
    ``radius`` (or ``sigma`` with shift held).
    """

    TITLE = "Remap Demo"
    DEFAULT_IMAGE = "lena.jpg"

    # Per-pixel displacement fields fed to cv2.remap.
    offsetx = Array
    offsety = Array
    # Pixel coordinate grids (float32) matching the image shape.
    gridx = Array
    gridy = Array

    need_redraw = Event
    # Set on mouse release so the next draw() commits the warp.
    need_update_map = Bool

    radius = Range(10, 50, 20)
    sigma = Range(10, 50, 30)

    # Undo stack of image copies.
    history = List()
    history_len = Property(depends_on="history")

    def __init__(self, **kw):
        super(RemapDemo, self).__init__(**kw)
        handlers = [
            ("button_press_event", self.on_figure_press),
            ("button_release_event", self.on_figure_release),
            ("motion_notify_event", self.on_figure_motion),
            ('scroll_event', self.on_scroll),
        ]
        self.figure.canvas_events = handlers
        self.connect_dirty("need_redraw")
        self.target_pos = None
        self.source_pos = None

    def control_panel(self):
        return VGroup(
            "radius",
            "sigma",
            Item("history_len", style="readonly"),
        )

    def _get_history_len(self):
        return len(self.history)

    def on_scroll(self, event):
        # Shift+scroll adjusts the blur, plain scroll the warp radius.
        name = "sigma" if event.key == "shift" else "radius"
        try:
            setattr(self, name, int(getattr(self, name) + event.step))
        except TraitError:
            # Values outside the Range bounds are simply ignored.
            pass

    def on_figure_press(self, event):
        if event.inaxes is self.axe and event.button == 1:
            self.target_pos = int(event.xdata), int(event.ydata)
        elif event.button == 3 and self.history:
            # Right click: undo the last committed warp.
            self.img[:] = self.history.pop()[:]
            self.need_redraw = True

    def on_figure_release(self, event):
        if not (event.inaxes is self.axe and event.button == 1):
            return
        self.source_pos = int(event.xdata), int(event.ydata)
        self.need_redraw = True
        self.need_update_map = True

    def on_figure_motion(self, event):
        if not (event.inaxes is self.axe and event.button == 1):
            return
        self.source_pos = int(event.xdata), int(event.ydata)
        self.need_redraw = True

    def _img_changed(self):
        shape = self.img.shape[:2]
        self.offsetx = np.zeros(shape, dtype=np.float32)
        self.offsety = np.zeros(shape, dtype=np.float32)
        self.gridy, self.gridx = np.mgrid[:self.img.shape[0],
                                          :self.img.shape[1]]
        self.gridx = self.gridx.astype(np.float32)
        self.gridy = self.gridy.astype(np.float32)

    def draw(self):
        self.offsetx.fill(0)
        self.offsety.fill(0)

        have_drag = self.source_pos is not None and self.target_pos is not None
        if have_drag:
            sx, sy = self.source_pos
            tx, ty = self.target_pos
            dist2 = (self.gridx - sx) ** 2 + (self.gridy - sy) ** 2
            inside = dist2 < self.radius ** 2
            self.offsetx[inside] = tx - sx
            self.offsety[inside] = ty - sy
            # Smooth the displacement field in place.
            cv2.GaussianBlur(self.offsetx, (0, 0), self.sigma, self.offsetx)
            cv2.GaussianBlur(self.offsety, (0, 0), self.sigma, self.offsety)

        warped = cv2.remap(self.img,
                           self.offsetx + self.gridx,
                           self.offsety + self.gridy,
                           cv2.INTER_LINEAR)
        self.draw_image(warped)

        if not self.need_update_map:
            return
        # Commit: save an undo snapshot (at most 10) and bake in the warp.
        self.history.append(self.img.copy())
        if len(self.history) > 10:
            del self.history[0]
        self.img[:] = warped[:]
        self.source_pos = self.target_pos = None
        self.need_update_map = False
class VolumeFactory(PipeFactory):
    """ Applies the Volume mayavi module to the given VTK data source
        (Mayavi source, or VTK dataset).

        **Note**

        The range of the colormap can be changed simply using the
        vmin/vmax parameters (see below). For more complex modifications of
        the colormap, here is some pseudo code to change the ctf (color
        transfer function), or the otf (opacity transfer function)::

            vol = mlab.pipeline.volume(src)

            # Changing the ctf:
            from tvtk.util.ctf import ColorTransferFunction
            ctf = ColorTransferFunction()
            ctf.add_rgb_point(value, r, g, b)  # r, g, and b are float
                                               # between 0 and 1
            ctf.add_hsv_point(value, h, s, v)
            # ...
            vol._volume_property.set_color(ctf)
            vol._ctf = ctf
            vol.update_ctf = True

            # Changing the otf:
            from tvtk.util.ctf import PiecewiseFunction
            otf = PiecewiseFunction()
            otf.add_point(value, opacity)
            vol._otf = otf
            vol._volume_property.set_scalar_opacity(otf)

        Also, it might be useful to change the range of the ctf::

            ctf.range = [0, 1]
    """

    color = Trait(
        None,
        None,
        TraitTuple(Range(0., 1.), Range(0., 1.), Range(0., 1.)),
        help="""the color of the vtk object. Overides the colormap, if any, when specified. This is specified as a triplet of float ranging from 0 to 1, eg (1, 1, 1) for white.""",
    )

    vmin = Trait(None, None, CFloat,
                 help="""vmin is used to scale the transparency gradient. If None, the min of the data will be used""")

    vmax = Trait(None, None, CFloat,
                 help="""vmax is used to scale the transparency gradient. If None, the max of the data will be used""")

    _target = Instance(modules.Volume, ())

    # Set once the existing ctf has been rescaled to vmin/vmax, so that
    # later vmin/vmax changes do not rescale it a second time.
    __ctf_rescaled = Bool(False)

    ######################################################################
    # Non-public interface.
    ######################################################################
    def _color_changed(self):
        """Replace the ctf by a single uniform color over the data range."""
        if not self.color:
            return
        range_min, range_max = self._target.current_range
        from tvtk.util.ctf import ColorTransferFunction
        ctf = ColorTransferFunction()
        try:
            ctf.range = (range_min, range_max)
        except Exception:
            # VTK versions < 5.2 don't seem to need this.
            pass

        r, g, b = self.color
        ctf.add_rgb_point(range_min, r, g, b)
        ctf.add_rgb_point(range_max, r, g, b)

        self._target._ctf = ctf
        self._target._volume_property.set_color(ctf)
        self._target.update_ctf = True

    def _vmin_changed(self):
        """Rebuild the opacity function (and, once, the ctf) from vmin/vmax.

        Handles both `vmin` and `vmax` changes (see `_vmax_changed` below);
        a None bound falls back to the corresponding end of the data range.
        """
        vmin = self.vmin
        vmax = self.vmax
        range_min, range_max = self._target.current_range
        if vmin is None:
            vmin = range_min
        if vmax is None:
            vmax = range_max

        # Change the opacity function: a linear ramp from 0 at vmin to 0.2
        # at vmax, with the data-range ends pinned outside [vmin, vmax].
        from tvtk.util.ctf import PiecewiseFunction, save_ctfs

        otf = PiecewiseFunction()
        if range_min < vmin:
            otf.add_point(range_min, 0.)
        if range_max > vmax:
            otf.add_point(range_max, 0.2)
        otf.add_point(vmin, 0.)
        otf.add_point(vmax, 0.2)
        self._target._otf = otf
        self._target._volume_property.set_scalar_opacity(otf)
        if self.color is None and not self.__ctf_rescaled and \
                ((self.vmin is not None) or (self.vmax is not None)):
            # FIXME: We don't use 'rescale_ctfs' because it screws up the
            # nodes.
            def _rescale_value(x):
                # Map a value in the data range onto [vmin, vmax].
                nx = (x - range_min) / (range_max - range_min)
                return vmin + nx * (vmax - vmin)

            # The range of the existing ctf can vary.
            scale_min, scale_max = self._target._ctf.range

            def _rescale_node(x):
                # Map a node of the existing ctf onto the data range.
                nx = (x - scale_min) / (scale_max - scale_min)
                return range_min + nx * (range_max - range_min)

            if hasattr(self._target._ctf, 'nodes'):
                rgb = list()
                for value in self._target._ctf.nodes:
                    r, g, b = \
                        self._target._ctf.get_color(value)
                    rgb.append((_rescale_node(value), r, g, b))
            else:
                rgb = save_ctfs(self._target.volume_property)['rgb']

            from tvtk.util.ctf import ColorTransferFunction
            ctf = ColorTransferFunction()
            try:
                ctf.range = (range_min, range_max)
            except Exception:
                # VTK versions < 5.2 don't seem to need this.
                pass
            rgb.sort()
            # Extend the first and last colors out to the data-range ends.
            v = rgb[0]
            ctf.add_rgb_point(range_min, v[1], v[2], v[3])
            for v in rgb:
                ctf.add_rgb_point(_rescale_value(v[0]), v[1], v[2], v[3])
            ctf.add_rgb_point(range_max, v[1], v[2], v[3])

            self._target._ctf = ctf
            self._target._volume_property.set_color(ctf)
            self.__ctf_rescaled = True

        self._target.update_ctf = True

    # `vmax` needs exactly the same update as `vmin`: without this alias,
    # setting only `vmax` would never rebuild the transfer functions.
    _vmax_changed = _vmin_changed
class A(HasTraits):
    """Fixture holding an Int trait and a Range with a very large upper bound."""

    # Plain integer trait.
    i = Int
    # Range whose upper bound is the largest signed 64-bit integer.
    r = Range(low=2, high=2 ** 63 - 1)
from ..editor_factory import EditorFactory
from ..toolkit import toolkit_object

# Currently, this trait is used only by the wx backend.
from ..helper import DockStyle

# -------------------------------------------------------------------------
#  Trait definitions:
# -------------------------------------------------------------------------

# Trait whose value is a BaseTraitHandler object.
handler_trait = Instance(BaseTraitHandler)

# Number of visible list rows: 1 to 50, default 5.
rows_trait = Range(
    low=1,
    high=50,
    value=5,
    desc="the number of list rows to display",
)

# Number of visible list columns: 1 to 10, default 1.
columns_trait = Range(
    low=1,
    high=10,
    value=1,
    desc="the number of list columns to display",
)

# Trait whose value is an EditorFactory instance.
editor_trait = Instance(EditorFactory)

# -------------------------------------------------------------------------
#  'ToolkitEditorFactory' class:
# -------------------------------------------------------------------------


class ToolkitEditorFactory(EditorFactory):
    """ Editor factory for list editors.
    """
class MultiLinePlotDemo(HasTraits):
    """Demonstrates the MultiLinePlot.

    This demo assumes that 'model', an instance of DataModel containing
    the 2D data to be plotted, will be given to the constructor, and will
    not change later.
    """

    model = Instance(DataModel)

    plot = Instance(Plot)

    multi_line_plot_renderer = Instance(MultiLinePlot)

    # Drives multi_line_plot_renderer.normalized_amplitude
    amplitude = Range(-1.5, 1.5, value=-0.5)

    # Drives multi_line_plot_renderer.offset
    offset = Range(-1.0, 1.0, value=0)

    traits_view = View(
        VGroup(
            Group(
                Item('plot', editor=ComponentEditor(), show_label=False),
            ),
            HGroup(
                Item('amplitude', springy=True),
                Item('offset', springy=True),
                springy=True,
            ),
            HGroup(
                Item('object.multi_line_plot_renderer.color', springy=True),
                Item('object.multi_line_plot_renderer.line_style',
                     springy=True),
                springy=True,
            ),
        ),
        width=800,
        height=500,
        resizable=True,
    )

    #-----------------------------------------------------------------------
    # Trait defaults
    #-----------------------------------------------------------------------

    def _multi_line_plot_renderer_default(self):
        """Build the MultiLinePlot from the model's index arrays and data."""
        model = self.model

        index_source = ArrayDataSource(model.x_index, sort_order='ascending')
        index_range = DataRange1D()
        index_range.add(index_source)

        yindex_source = ArrayDataSource(model.y_index,
                                        sort_order='ascending')
        yindex_range = DataRange1D()
        yindex_range.add(yindex_source)

        # The data source for the MultiLinePlot.
        data_source = MultiArrayDataSource(data=model.data)

        return MultiLinePlot(
            index=index_source,
            yindex=yindex_source,
            index_mapper=LinearMapper(range=index_range),
            value_mapper=LinearMapper(range=yindex_range),
            value=data_source,
            global_max=model.data.max(),
            global_min=model.data.min(),
        )

    def _plot_default(self):
        """Build the Plot holding the renderer plus bottom/left axes."""
        plot = Plot(title="MultiLinePlot Demo")
        renderer = self.multi_line_plot_renderer
        plot.add(renderer)

        bottom_axis = PlotAxis(component=plot,
                               mapper=renderer.index_mapper,
                               orientation='bottom',
                               title='t (seconds)')
        left_axis = PlotAxis(component=plot,
                             mapper=renderer.value_mapper,
                             orientation='left',
                             title='channel')
        plot.overlays.extend([bottom_axis, left_axis])
        return plot

    #-----------------------------------------------------------------------
    # Trait change handlers
    #-----------------------------------------------------------------------

    def _amplitude_changed(self, new_amplitude):
        self.multi_line_plot_renderer.normalized_amplitude = new_amplitude

    def _offset_changed(self, new_offset):
        renderer = self.multi_line_plot_renderer
        renderer.offset = new_offset
        # FIXME: The change does not trigger a redraw.  Force a redraw by
        # faking an amplitude change.
        renderer._amplitude_changed()
class FieldViewer(HasTraits):
    """Interactive viewer for iso-surfaces of a 3D scalar field f(x, y, z)."""

    # Value ranges of the three axes.
    x0, x1 = Float(-5), Float(5)
    y0, y1 = Float(-5), Float(5)
    z0, z1 = Float(-5), Float(5)
    points = Int(50)  # number of sample points along each axis
    autocontour = Bool(True)  # whether iso-surface values are chosen automatically
    v0, v1 = Float(0.0), Float(1.0)  # range of values the iso-surface can take
    contour = Range("v0", "v1", 0.5)  # value of the manual iso-surface
    function = Str("x*x*0.5 + y*y + z*z*2.0")  # scalar field expression
    function_list = [
        "x*x*0.5 + y*y + z*z*2.0",
        "x*y*0.5 + np.sin(2*x)*y +y*z*2.0",
        "x*y*z",
        "np.sin((x*x+y*y)/z)"
    ]
    plotbutton = Button("描画")
    scene = Instance(MlabSceneModel, ())

    view = View(
        HSplit(
            VGroup(
                "x0", "x1", "y0", "y1", "z0", "z1",
                Item('points', label="点数"),
                Item('autocontour', label="自动等值"),
                Item('plotbutton', show_label=False),
            ),
            VGroup(
                Item('scene',
                     editor=SceneEditor(scene_class=MayaviScene),
                     resizable=True,
                     height=300,
                     width=350),
                Item('function',
                     editor=EnumEditor(name='function_list',
                                       evaluate=lambda x: x)),
                Item('contour',
                     editor=RangeEditor(format="%1.2f",
                                        low_name="v0",
                                        high_name="v1")),
                show_labels=False)),
        width=500,
        resizable=True,
        title="三维标量场观察器")

    def _plotbutton_fired(self):
        self.plot()

    def plot(self):
        """Sample the field, draw its iso-surfaces and an X-Y cut plane."""
        # Build the 3D sampling grid.  The local names x, y and z are read
        # by the eval() below and must keep these exact names.
        x, y, z = np.mgrid[
            self.x0:self.x1:1j * self.points,
            self.y0:self.y1:1j * self.points,
            self.z0:self.z1:1j * self.points]

        # Evaluate the scalar field from the expression text.
        # NOTE(review): eval() runs arbitrary user-entered text; acceptable
        # only because this is a local interactive demo.
        scalars = eval(self.function)
        self.scene.mlab.clf()  # clear the current scene

        # Draw the iso-surfaces.
        surface = self.scene.mlab.contour3d(x, y, z, scalars, contours=8,
                                            transparent=True)
        surface.contour.auto_contours = self.autocontour
        self.scene.mlab.axes(figure=self.scene.mayavi_scene)  # add axes

        # Add a cut plane through the center of the box; the (0, 0, 1)
        # normal makes it parallel to the X-Y plane.
        cut = self.scene.mlab.pipeline.scalar_cut_plane(surface)
        center = ((self.x0 + self.x1) / 2,
                  (self.y0 + self.y1) / 2,
                  (self.z0 + self.z1) / 2)
        cut.implicit_plane.normal = (0, 0, 1)
        cut.implicit_plane.origin = center

        self.g = surface
        self.scalars = scalars
        # Bound the manual iso-value by the range of the field.
        self.v0 = np.min(scalars)
        self.v1 = np.max(scalars)

    def _contour_changed(self):
        # Only applies once something has been plotted and only when the
        # iso-value is chosen manually.
        if not hasattr(self, "g"):
            return
        if not self.g.contour.auto_contours:
            self.g.contour.contours = [self.contour]

    def _autocontour_changed(self):
        if not hasattr(self, "g"):
            return
        self.g.contour.auto_contours = self.autocontour
        if not self.autocontour:
            # Switching to manual mode: apply the current manual value.
            self._contour_changed()
class Visualization(HasTraits):
    """Mayavi view of an estimated 1001x1001 surface with an ``lmax`` slider.

    The surface is recomputed via ``estimate`` whenever ``lmax`` changes,
    and is always clipped to the 0..255 height range of the outline box.
    """

    lmax = Range(0, 30, 6)

    scene = Instance(MlabSceneModel, ())

    # the layout of the dialog created
    view = View(
        Item('scene', editor=SceneEditor(scene_class=MayaviScene),
             height=250, width=300, show_label=False),
        HGroup(
            '_', 'lmax',
        ),
    )

    def __init__(self, x, y, z, theta, phi, THETAnorm, PHInorm, mask):
        """Store the samples and draw the initial surface.

        `x` and `y` are accepted for interface compatibility but are not
        used.  `mask` is inverted (``~mask``) to build the mesh mask.
        """
        # Do not forget to call the parent's __init__
        HasTraits.__init__(self)

        self.theta = theta
        self.phi = phi
        self.THETAnorm = THETAnorm
        self.PHInorm = PHInorm
        self.z = z
        self.mask = mask

        #
        # Visualize the results.
        #
        self.Ximg, self.Yimg = np.mgrid[0:1001, 0:1001]
        # Clip the initial surface the same way `update_plot` does, so the
        # first render agrees with later ones.
        Zimg = self._compute_zimg()

        self.mesh = self.scene.mlab.mesh(self.Ximg, self.Yimg, Zimg,
                                         mask=~mask)

        # Project the sample directions onto the 1001x1001 image plane.
        x_undistort = (theta / (np.pi / 2) * np.sin(phi) + 1) * 1001 / 2
        y_undistort = (theta / (np.pi / 2) * np.cos(phi) + 1) * 1001 / 2

        self.scene.mlab.points3d(x_undistort, y_undistort, self.z,
                                 mode='sphere', scale_mode='none',
                                 scale_factor=5, color=(0, 0, 1))
        self.scene.mlab.outline(color=(0, 0, 0),
                                extent=(0, 1001, 0, 1001, 0, 255))

    def _compute_zimg(self):
        """Evaluate the surface for the current `lmax`, clipped to 0..255."""
        Zimg = estimate(self.z, self.theta, self.phi, self.THETAnorm,
                        self.PHInorm, self.lmax)
        return np.clip(Zimg, 0, 255)

    @on_trait_change('lmax')
    def update_plot(self):
        Zimg = self._compute_zimg()
        self.mesh.mlab_source.set(x=self.Ximg, y=self.Yimg, z=Zimg,
                                  mask=~self.mask)
class A(HasTraits):
    """Fixture holding Int and Long traits plus a LONG_TYPE-bounded Range."""

    # Plain integer trait.
    i = Int
    # Long-integer trait.
    l = Long
    # Both bounds go through LONG_TYPE; the upper one equals 2**63 - 1.
    r = Range(low=LONG_TYPE(2), high=LONG_TYPE(9223372036854775807))
class FileDataSource(Source):
    """Source backed by a (possibly time-series) list of data files.

    Selecting `timestep` chooses which file of `file_list` is current;
    `play`/`loop`/`play_delay` drive a timer that steps through the files.
    """

    # The version of this class.  Used for persistence.
    __version__ = 0

    # The list of file names for the timeseries.
    file_list = List(Str, desc='a list of files belonging to a time series')

    # The current time step (starts with 0).  This trait is a dummy
    # and is dynamically changed when the `file_list` trait changes.
    # This is done so the timestep bounds are linked to the number of
    # the files in the file list.
    timestep = Range(value=0,
                     low='_min_timestep',
                     high='_max_timestep',
                     enter_set=True, auto_set=False,
                     desc='the current time step')

    sync_timestep = Bool(False, desc='if all dataset timesteps are synced')

    play = Bool(False, desc='if timesteps are automatically updated')
    play_delay = Float(0.2, desc='the delay between loading files')
    loop = Bool(False, desc='if animation is looped')

    update_files = Button('Rescan files')

    base_file_name = Str('', desc="the base name of the file",
                         enter_set=True, auto_set=False,
                         editor=FileEditor())

    # A timestep view group that may be included by subclasses.
    time_step_group = Group(
        Item(name='file_path', style='readonly'),
        Group(
            Item(name='timestep',
                 editor=RangeEditor(
                     low=0, high_name='_max_timestep',
                     mode='slider'
                 ),
                 ),
            Item(name='sync_timestep'),
            HGroup(
                Item(name='play'),
                Item(name='play_delay',
                     label='Delay'),
                Item(name='loop'),
            ),
            visible_when='len(object.file_list) > 1'
        ),
        Item(name='update_files', show_label=False),
    )

    ##################################################
    # Private traits.
    ##################################################

    # The current file name.  This is not meant to be touched by the
    # user.
    file_path = Instance(FilePath, (), desc='the current file name')

    _min_timestep = Int(0)
    _max_timestep = Int(0)

    # Playback timer; None when not playing.
    _timer = Any

    # Re-entrancy guard for `_update_files_fired`.
    _in_update_files = Any(False)

    ######################################################################
    # `object` interface
    ######################################################################
    def __get_pure_state__(self):
        d = super(FileDataSource, self).__get_pure_state__()
        # These are obtained dynamically, so don't pickle them.
        for x in ['file_list', 'timestep', 'play']:
            d.pop(x, None)
        return d

    def __set_pure_state__(self, state):
        # Use the saved path to initialize the file_list and timestep.
        fname = state.file_path.abs_pth
        if not isfile(fname):
            msg = 'Could not find file at %s\n' % fname
            msg += 'Please move the file there and try again.'
            raise IOError(msg)

        self.initialize(fname)
        # Now set the remaining state without touching the children.
        set_state(self, state, ignore=['children', 'file_path'])
        # Setup the children.
        handle_children_state(self.children, state.children)
        # Setup the children's state.
        set_state(self, state, first=['children'], ignore=['*'])

    ######################################################################
    # `FileDataSource` interface
    ######################################################################
    def initialize(self, base_file_name):
        """Given a single filename which may or may not be part of a
        time series, this initializes the list of files.  This method
        need not be called to initialize the data.
        """
        self.base_file_name = base_file_name

    ######################################################################
    # Non-public interface
    ######################################################################
    def _file_list_changed(self, value):
        # Change the range of the timestep suitably to reflect new list.
        n_files = len(self.file_list)
        timestep = max(min(self.timestep, n_files - 1), 0)
        if self.timestep == timestep:
            # Same index, but the file it points to may have changed.
            self._timestep_changed(timestep)
        else:
            self.timestep = timestep
        self._max_timestep = max(n_files - 1, 0)

    def _file_list_items_changed(self, list_event):
        self._file_list_changed(self.file_list)

    def _timestep_changed(self, value):
        file_list = self.file_list
        if len(file_list) > 0:
            self.file_path = FilePath(file_list[value])
        else:
            self.file_path = FilePath('')
        if self.sync_timestep:
            for sibling in self._find_sibling_datasets():
                sibling.timestep = value

    def _base_file_name_changed(self, value):
        self._update_files_fired()
        try:
            self.timestep = self.file_list.index(value)
        except ValueError:
            self.timestep = 0

    def _play_changed(self, value):
        mm = getattr(self.scene, 'movie_maker', None)
        if value:
            if mm is not None:
                mm.animation_start()
            # `_loop_changed` re-enters this branch while already playing;
            # stop the old timer so two timers never run at the same time.
            if self._timer is not None:
                self._timer.Stop()
            self._timer = self._make_play_timer()
            if not self._timer.IsRunning():
                self._timer.Start()
        else:
            if self._timer is not None:
                self._timer.Stop()
            self._timer = None
            if mm is not None:
                mm.animation_stop()

    def _loop_changed(self, value):
        # Restart playback (e.g. after it stopped at the last file).
        if value and self.play:
            self._play_changed(self.play)

    def _play_event(self):
        """Timer callback: advance to the next timestep."""
        mm = getattr(self.scene, 'movie_maker', None)
        nf = self._max_timestep
        pc = self.timestep
        pc += 1
        if pc > nf:
            if self.loop:
                pc = 0
            else:
                self._timer.Stop()
                pc = nf
                if mm is not None:
                    mm.animation_stop()
        if pc != self.timestep:
            self.timestep = pc
            if mm is not None:
                mm.animation_step()

    def _play_delay_changed(self):
        if self.play:
            self._timer.Stop()
            self._timer.Start(self.play_delay * 1000)

    def _make_play_timer(self):
        scene = self.scene
        if scene is None or scene.off_screen_rendering:
            timer = NoUITimer(self.play_delay * 1000, self._play_event)
        else:
            from pyface.timer.api import Timer
            timer = Timer(self.play_delay * 1000, self._play_event)
        return timer

    def _find_sibling_datasets(self):
        """Return children of our parent with the same number of steps."""
        if self.parent is not None:
            nt = self._max_timestep
            # The parent's children need not all be file data sources, so
            # tolerate siblings without a `_max_timestep` trait.
            return [x for x in self.parent.children
                    if getattr(x, '_max_timestep', None) == nt]
        else:
            return []

    def _update_files_fired(self):
        # First get all the siblings before we change the current file list.
        if self._in_update_files:
            return
        try:
            self._in_update_files = True
            if self.sync_timestep:
                siblings = self._find_sibling_datasets()
            else:
                siblings = []
            fname = self.base_file_name
            file_list = get_file_list(fname)
            if len(file_list) == 0:
                file_list = [fname]
            self.file_list = file_list
            for sibling in siblings:
                sibling.update_files = True
        finally:
            self._in_update_files = False
class Catenary(HasTraits):
    """Mass-spring model of a hanging chain integrated with scipy's ``ode``.

    The chain has ``N`` nodes starting evenly spaced on y = 0 for x in
    [0, 1].  The two end nodes never move (their accelerations are always
    zero); interior nodes feel spring forces (stiffness ``k``, total rest
    length ``length``), a constant ``g`` added to the y acceleration, and
    velocity damping ``dump``.
    """
    N = Int(31)
    dump = Range(0.0, 0.5, 0.1)  # velocity damping coefficient
    k = Range(1.0, 100.0, 20.0)  # spring stiffness
    length = Range(1.0, 3.0, 1.0)  # total rest length of the chain
    g = Range(0.0, 0.1, 0.01)  # constant added to every interior y acceleration
    t = Float(0.0)
    x = Property()
    y = Property()

    def __init__(self, *args, **kw):
        super(Catenary, self).__init__(*args, **kw)
        x0 = np.linspace(0, 1, self.N)
        y0 = np.zeros_like(x0)
        vx0 = np.zeros_like(x0)
        vy0 = np.zeros_like(x0)
        # State vector layout: [x..., y..., vx..., vy...].
        self.status0 = np.r_[x0, y0, vx0, vy0]
        # Make the `x`/`y` properties usable before `ode_init` is called.
        self.status = self.status0

    def diff_status(self, t, status):
        """Right-hand side d(status)/dt for ``scipy.integrate.ode``."""
        x, y, vx, vy = status.reshape(4, -1)
        dvx = np.zeros_like(x)
        dvy = np.zeros_like(x)
        dx = vx
        dy = vy

        s = np.s_[1:-1]  # interior nodes only; the ends stay fixed

        rest_len = self.length / (self.N - 1)  # rest length of one segment
        k = self.k
        g = self.g
        dump = self.dump

        # Current lengths of the segments to the previous and next node.
        l1 = np.sqrt((x[s] - x[:-2]) ** 2 + (y[s] - y[:-2]) ** 2)
        l2 = np.sqrt((x[s] - x[2:]) ** 2 + (y[s] - y[2:]) ** 2)
        dl1 = (l1 - rest_len) / l1
        dl2 = (l2 - rest_len) / l2
        dvx[s] = -(x[s] - x[:-2]) * k * dl1 - (x[s] - x[2:]) * k * dl2
        dvy[s] = -(y[s] - y[:-2]) * k * dl1 - (y[s] - y[2:]) * k * dl2 + g
        dvx[s] -= vx[s] * dump
        dvy[s] -= vy[s] * dump
        return np.r_[dx, dy, dvx, dvy]

    def ode_init(self):
        """(Re)start the integration from the initial state at t = 0."""
        self.t = 0
        self.system = integrate.ode(self.diff_status)
        self.system.set_integrator("vode", method="bdf")
        self.system.set_initial_value(self.status0, 0)
        self.status = self.status0

    def ode_step(self, dt):
        """Advance the simulation by `dt`."""
        self.system.integrate(self.t + dt)
        self.t = self.system.t
        self.status = self.system.y

    def _get_x(self):
        return self.status[:self.N]

    def _get_y(self):
        return self.status[self.N:self.N * 2]
class ArcticDB(WritableFactorDB):
    """Factor database stored in an Arctic server.

    Each table is an Arctic ``CHUNK_STORE`` library. Inside a library:

    * the symbol ``"_FactorInfo"`` holds one row per factor, keyed by the
      ``"FactorName"`` column (``writeData``/``_writeNewData`` fill a
      ``"DataType"`` column with ``"double"``/``"string"`` or a caller value);
    * every other symbol is one ID and stores a ``"date"``-indexed frame whose
      columns carry positional labels (``"0"``, ``"1"``, ...). The symbol's
      metadata maps those labels back to factor names through the parallel
      lists ``"FactorNames"`` and ``"Cols"``.
    """

    DBName = Str("arctic", arg_type="String", label="数据库名", order=0)
    IPAddr = Str("127.0.0.1", arg_type="String", label="IP地址", order=1)
    Port = Range(low=0, high=65535, value=27017, arg_type="Integer", label="端口", order=2)
    User = Str("", arg_type="String", label="用户名", order=3)
    Pwd = Password("", arg_type="String", label="密码", order=4)

    def __init__(self, sys_args=None, config_file=None, **kwargs):
        self._Arctic = None  # arctic.Arctic object; None while disconnected
        if sys_args is None:
            # Fresh dict per call instead of a shared mutable default.
            sys_args = {}
        if config_file is None:
            config_file = __QS_ConfigPath__ + os.sep + "ArcticDBConfig.json"
        super().__init__(sys_args=sys_args, config_file=config_file, **kwargs)
        self.Name = "ArcticDB"

    def __getstate__(self):
        state = self.__dict__.copy()
        # The live Arctic object is not picklable; store only whether a
        # connection was open so that __setstate__ can reconnect.
        state["_Arctic"] = self.isAvailable()
        return state

    def __setstate__(self, state):
        super().__setstate__(state)
        if self._Arctic:
            self.connect()
        else:
            self._Arctic = None

    def connect(self):
        # NOTE(review): only IPAddr is passed to arctic.Arctic; the Port, User
        # and Pwd traits are not used here. Confirm that this is intended.
        self._Arctic = arctic.Arctic(self.IPAddr)
        return 0

    def disconnect(self):
        self._Arctic = None
        return 1

    def isAvailable(self):
        return (self._Arctic is not None)

    @property
    def TableNames(self):
        return sorted(self._Arctic.list_libraries())

    def getTable(self, table_name, args=None):
        if table_name not in self._Arctic.list_libraries():
            raise __QS_Error__("表 '%s' 不存在!" % table_name)
        # Fresh dict per call instead of a shared mutable default.
        SysArgs = {} if args is None else args
        return _FactorTable(name=table_name, fdb=self, sys_args=SysArgs)

    def renameTable(self, old_table_name, new_table_name):
        self._Arctic.rename_library(old_table_name, new_table_name)
        return 0

    def deleteTable(self, table_name):
        self._Arctic.delete_library(table_name)
        return 0

    def setTableMetaData(self, table_name, key=None, value=None, meta_data=None):
        """Merge ``meta_data`` and/or ``key=value`` into the table-level metadata."""
        Lib = self._Arctic[table_name]
        TableInfo = Lib.read_metadata("_FactorInfo")
        if TableInfo is None:
            TableInfo = {}
        if meta_data is not None:
            TableInfo.update(dict(meta_data))
        if key is not None:
            TableInfo[key] = value
        Lib.write_metadata("_FactorInfo", TableInfo)
        return 0

    def renameFactor(self, table_name, old_factor_name, new_factor_name):
        if table_name not in self._Arctic.list_libraries():
            raise __QS_Error__("表: '%s' 不存在!" % table_name)
        Lib = self._Arctic[table_name]
        FactorInfo = Lib.read(symbol="_FactorInfo").set_index(["FactorName"])
        if old_factor_name not in FactorInfo.index:
            raise __QS_Error__("因子: '%s' 不存在!" % old_factor_name)
        if new_factor_name in FactorInfo.index:
            raise __QS_Error__("因子: '%s' 已经存在!" % new_factor_name)
        FactorNames = FactorInfo.index.tolist()
        FactorNames[FactorNames.index(old_factor_name)] = new_factor_name
        FactorInfo.index = FactorNames
        FactorInfo.index.name = "FactorName"
        self._writeFactorInfo(Lib, FactorInfo.reset_index())
        # Stored data columns use positional labels, so only the per-ID
        # metadata name list has to change.
        IDs = Lib.list_symbols()
        IDs.remove("_FactorInfo")
        for iID in IDs:
            iMetaData = Lib.read_metadata(iID)
            iFactorNames = iMetaData["FactorNames"]
            if old_factor_name in iFactorNames:
                iFactorNames[iFactorNames.index(old_factor_name)] = new_factor_name
                Lib.write_metadata(iID, iMetaData)
        return 0

    def deleteFactor(self, table_name, factor_names):
        if table_name not in self._Arctic.list_libraries():
            return 0
        Lib = self._Arctic[table_name]
        FactorInfo = Lib.read(symbol="_FactorInfo").set_index(["FactorName"])
        FactorInfo = FactorInfo.loc[FactorInfo.index.difference(factor_names)]
        if FactorInfo.shape[0] == 0:
            return self.deleteTable(table_name)
        IDs = Lib.list_symbols()
        IDs.remove("_FactorInfo")
        for iID in IDs:
            iMetaData = Lib.read_metadata(iID)
            # Stored column label of each factor of this ID, indexed by factor name.
            iFactorIndex = pd.Series(iMetaData["Cols"], index=iMetaData["FactorNames"])
            iFactorIndex = iFactorIndex[iFactorIndex.index.difference(factor_names)]
            if iFactorIndex.shape[0] == 0:
                Lib.delete(iID)
                continue
            iKeptCols = iFactorIndex.values.tolist()
            iData = Lib.read(symbol=iID, columns=iKeptCols)
            # Relabel the surviving columns positionally.
            iCols = [str(i) for i in range(iFactorIndex.shape[0])]
            iData.columns = iCols
            # The metadata must keep the factor *names* (the Series index), not
            # the old column labels (its values) as was previously written.
            iMetaData["FactorNames"], iMetaData["Cols"] = iFactorIndex.index.tolist(), iCols
            Lib.write(iID, iData, metadata=iMetaData)
        self._writeFactorInfo(Lib, FactorInfo.reset_index())
        return 0

    def setFactorMetaData(self, table_name, ifactor_name, key=None, value=None, meta_data=None):
        if (key is None) and (meta_data is None):
            return 0
        Lib = self._Arctic[table_name]
        FactorInfo = Lib.read(symbol="_FactorInfo").set_index(["FactorName"])
        if key is not None:
            FactorInfo.loc[ifactor_name, key] = value
        if meta_data is not None:
            for iKey in meta_data:
                FactorInfo.loc[ifactor_name, iKey] = meta_data[iKey]
        self._writeFactorInfo(Lib, FactorInfo.reset_index())
        return 0

    def writeData(self, data, table_name, if_exists="update", data_type=None, **kwargs):
        """Write a factor panel (items=factors, major_axis=dates, minor_axis=IDs).

        With ``if_exists == "update"``, incoming values overwrite stored ones.
        Otherwise ``DataFrame.update`` is called with ``overwrite=False``.
        ``data_type`` optionally maps factor name -> declared data type.
        """
        if data_type is None:
            data_type = {}
        if data.shape[0] == 0:
            return 0
        if table_name not in self._Arctic.list_libraries():
            return self._writeNewData(data, table_name, data_type=data_type)
        Lib = self._Arctic[table_name]
        DataCols = [str(i) for i in range(data.shape[0])]
        DTRange = data.major_axis
        OverWrite = (if_exists == "update")
        for i, iID in enumerate(data.minor_axis):
            iData = data.iloc[:, :, i]
            if not Lib.has_symbol(iID):
                iMetaData = {"FactorNames": iData.columns.tolist(), "Cols": DataCols}
                iData.index.name, iData.columns = "date", DataCols
                Lib.write(iID, iData, metadata=iMetaData)
                continue
            iMetaData = Lib.read_metadata(symbol=iID)
            iOldFactorNames, iCols = iMetaData["FactorNames"], iMetaData["Cols"]
            iNewFactorNames = iData.columns.difference(iOldFactorNames).tolist()
            iOldData = Lib.read(symbol=iID, chunk_range=DTRange, filter_data=True)
            if iOldData.shape[0] > 0:
                iOldData.columns = iOldFactorNames
                iOldData = iOldData.loc[iOldData.index.union(iData.index), iOldFactorNames + iNewFactorNames]
                iOldData.update(iData, overwrite=OverWrite)
            else:
                iOldData = iData.loc[:, iOldFactorNames + iNewFactorNames]
            if iNewFactorNames:
                # New factors get labels above every existing label.
                # iOldData.shape[1] already includes the new factors, so
                # counting from it could reuse a label already in iCols.
                iNextCol = max((int(iCol) for iCol in iCols), default=-1) + 1
                iCols = iCols + [str(iNextCol + j) for j in range(len(iNewFactorNames))]
            iOldData.index.name, iOldData.columns = "date", iCols
            iMetaData["FactorNames"], iMetaData["Cols"] = iOldFactorNames + iNewFactorNames, iCols
            Lib.update(iID, iOldData, metadata=iMetaData, chunk_range=DTRange)
        # Register factors that the table did not know yet.
        FactorInfo = Lib.read(symbol="_FactorInfo").set_index("FactorName")
        NewFactorNames = data.items.difference(FactorInfo.index).tolist()
        FactorInfo = FactorInfo.loc[FactorInfo.index.tolist() + NewFactorNames, :]
        for iFactorName in NewFactorNames:
            if iFactorName in data_type:
                FactorInfo.loc[iFactorName, "DataType"] = data_type[iFactorName]
            elif self._hasObjectColumn(data.loc[iFactorName]):
                FactorInfo.loc[iFactorName, "DataType"] = "string"
            else:
                FactorInfo.loc[iFactorName, "DataType"] = "double"
        self._writeFactorInfo(Lib, FactorInfo.reset_index())
        return 0

    def _writeNewData(self, data, table_name, data_type):
        """Create ``table_name`` as a CHUNK_STORE library and write ``data`` into it."""
        FactorNames = data.items.tolist()
        DataType = pd.Series("double", index=data.items)
        for i, iFactorName in enumerate(DataType.index):
            if iFactorName in data_type:
                DataType.iloc[i] = data_type[iFactorName]
            elif self._hasObjectColumn(data.iloc[i]):
                DataType.iloc[i] = "string"
        DataCols = [str(i) for i in range(data.shape[0])]
        self._Arctic.initialize_library(table_name, lib_type=arctic.CHUNK_STORE)
        Lib = self._Arctic[table_name]
        # Temporarily relabel the caller's panel items positionally so that each
        # per-ID frame carries the stored column labels. The original labels
        # are restored even if a write fails.
        data.items = DataCols
        try:
            for i, iID in enumerate(data.minor_axis):
                iData = data.iloc[:, :, i]
                iMetaData = {"FactorNames": FactorNames, "Cols": DataCols}
                iData.index.name = "date"
                Lib.write(iID, iData, metadata=iMetaData)
        finally:
            data.items = FactorNames
        DataType = DataType.reset_index()
        DataType.columns = ["FactorName", "DataType"]
        self._writeFactorInfo(Lib, DataType)
        return 0

    @staticmethod
    def _writeFactorInfo(lib, factor_info):
        """Write ``factor_info`` to ``"_FactorInfo"`` using Arctic's PassthroughChunker."""
        lib.write(
            "_FactorInfo", factor_info,
            chunker=arctic.chunkstore.passthrough_chunker.PassthroughChunker())

    @staticmethod
    def _hasObjectColumn(frame):
        """Return True if any column of ``frame`` has the object dtype."""
        # ``np.dtype('O') in frame.dtypes`` would search the Series *index*
        # (the column labels), not the dtypes, so compare each dtype instead.
        return any(iDType == np.dtype("O") for iDType in frame.dtypes)