class BaseImageController(ControllerBase):
    """Controller that displays one 2D image at a time from the data store.

    The displayed image is node number ``self.selected_index`` among the
    nodes listed under ``self.data_path`` in ``self.chest``.
    """
    plot = Instance(BasePlotContainer)
    plotdata = Instance(ArrayPlotData)

    def __init__(self, parent, treasure_chest=None, data_path='/rawdata',
                 *args, **kw):
        super(BaseImageController, self).__init__(parent, treasure_chest,
                                                  data_path, *args, **kw)
        self.plotdata = ArrayPlotData()
        self._can_save = True
        self._can_change_idx = True

    def init_plot(self):
        """Load the active image into the plot data and build the plot."""
        self.plotdata.set_data('imagedata', self.get_active_image())
        self.plot = self.get_simple_image_plot(array_plot_data=self.plotdata,
                                               title=self.get_active_name())

    def data_updated(self):
        """Reinitialize the controller after the underlying data changed."""
        # reinitialize data
        self.__init__(parent=self.parent, treasure_chest=self.chest,
                      data_path=self.data_path)

    def _active_nodes(self):
        """Return the nodes stored under this controller's data path."""
        # Use the configured data path instead of a hard-coded '/rawdata',
        # so the controller shows the data it was constructed for.
        return self.chest.list_nodes(self.data_path)

    # this is a 2D image for plotting purposes
    def get_active_image(self):
        """Return the array of the selected node, or None if there are none."""
        nodes = self._active_nodes()
        if len(nodes) > 0:
            return nodes[self.selected_index][:]
        return None

    def get_active_name(self):
        """Return the name of the selected node."""
        nodes = self._active_nodes()
        return nodes[self.selected_index].name

    @on_trait_change("selected_index")
    def update_image(self):
        """Show the newly selected image.

        If its shape differs from the previously shown image, reset the grid
        data source and rebuild the plot with a matching aspect ratio.
        """
        if self.chest is None or self.numfiles < 1:
            return
        # get the old image for the sake of comparing image sizes
        old_data = self.plotdata.get_data('imagedata')
        if old_data is None:
            # Nothing has been plotted yet, so there is no old shape to
            # compare against and no grid source to update: build from scratch.
            self.init_plot()
            return
        active_image = self.get_active_image()
        self.plotdata.set_data("imagedata", active_image)
        self.set_plot_title(self.get_active_name())
        if old_data.shape != active_image.shape:
            grid_data_source = self._base_plot.range2d.sources[0]
            grid_data_source.set_data(np.arange(active_image.shape[1]),
                                      np.arange(active_image.shape[0]))
            self.plot = self.get_simple_image_plot(array_plot_data=self.plotdata,
                                                   title=self.get_active_name())
            self.plot.aspect_ratio = (float(active_image.shape[1])
                                      / active_image.shape[0])

    def open_save_UI(self, plot_id='plot'):
        """Open a 'save' session for the plot identified by *plot_id*."""
        save_controller = SaveFileController(plot=self.get_plot(plot_id),
                                             parent=self)
        save_dialog = simple_session('save', 'Save dialog', SavePlotDialog,
                                     controller=save_controller)
        Application.instance().add_factories([save_dialog])
        session_id = Application.instance().start_session('save')
        save_controller._session_id = session_id
class ChacoReporter(StateDataReporter, HasTraits):
    """A StateDataReporter that also plots every reported column with Chaco."""
    plots = Instance(VPlotContainer)
    labels = List

    traits_view = View(
        Group(Item('plots', editor=ComponentEditor(), show_label=False)),
        width=800, height=600, resizable=True,
        title='OpenMM')

    def construct_plots(self):
        """Build the Chaco plots.

        This will be run on the first report. One scatter plot is added per
        reported column, except the column chosen as the x axis and the
        other x-axis candidate.

        Raises
        ------
        ValueError
            If neither 'Time (ps)' nor 'Step' is among the reported columns.
        """
        self.labels = super(ChacoReporter, self)._headers()
        self.plots = VPlotContainer(resizable="hv", bgcolor="lightgray",
                                    fill_padding=True, padding=10)
        # Equivalent to ArrayPlotData(a=[], b=[], c=[]) for keys a, b, c.
        # Each key gets its own list object.
        self.plotdata = ArrayPlotData(**{label: [] for label in self.labels})

        # figure out which key will be the x axis
        x = None
        x_labels = ['Time (ps)', 'Step']
        for possible_x in x_labels:
            if possible_x in self.labels:
                x = possible_x
                break
        if x is None:
            raise ValueError("The reporter published neither the step nor time "
                             "count, so I don't know what to plot on the x-axis!")

        colors = itertools.cycle(['blue', 'green', 'silver', 'pink',
                                  'lightblue', 'red', 'darkgray', 'lightgreen'])
        # Keep the reporter's column order so the plot layout is deterministic
        # (iterating a set gave an arbitrary order).
        y_labels = [label for label in self.labels if label not in x_labels]
        for y in y_labels:
            self.plots.add(chaco_scatter(self.plotdata, x_name=x, y_name=y,
                                         color=next(colors)))

    def _constructReportValues(self, simulation, state):
        """Compute the report values and append each one to its plot series."""
        values = super(ChacoReporter, self)._constructReportValues(simulation,
                                                                   state)
        for i, label in enumerate(self.labels):
            current = self.plotdata.get_data(label)
            self.plotdata.set_data(label, np.r_[current, float(values[i])])
        return values

    def report(self, simulation, state):
        """Build the plots on the first call, then delegate to the base report."""
        # NOTE(review): `_hasInitialized` is presumably set by the base
        # StateDataReporter.report() on its first call -- confirm.
        if not self._hasInitialized:
            self.construct_plots()
        super(ChacoReporter, self).report(simulation, state)
def test_data_changed_events(self):
    """ArrayPlotData reports added/changed/removed names through events."""
    # Two arrays of different shapes used as successive values of one key.
    first = numpy.ones((3, 4))
    replacement = numpy.zeros(16)
    store = ArrayPlotData()

    def expect_events(action, expected):
        # Run `action` while recording data_changed events, then compare.
        with self.monitor_events(store) as recorded:
            action()
            self.assertEqual(recorded, expected)

    expect_events(lambda: store.set_data('Grumpy', first),
                  [{'added': ['Grumpy']}])

    # get_data must hand back the very object that was stored.
    self.assertIs(store.get_data('Grumpy'), first)

    expect_events(lambda: store.set_data('Grumpy', replacement),
                  [{'changed': ['Grumpy']}])
    expect_events(lambda: store.del_data('Grumpy'),
                  [{'removed': ['Grumpy']}])
def test_data_changed_events(self):
    """Adding, replacing and deleting a key each fire one data_changed event."""
    key = "Grumpy"
    original_array = numpy.ones((3, 4))
    new_array = numpy.zeros(16)
    apd = ArrayPlotData()

    # A first assignment to an unknown key is reported as "added".
    with self.monitor_events(apd) as events:
        apd.set_data(key, original_array)
        self.assertEqual(events, [{"added": [key]}])

    # get_data returns the stored object itself, not a copy.
    retrieved = apd.get_data(key)
    self.assertIs(retrieved, original_array)

    # Re-assigning an existing key is reported as "changed".
    with self.monitor_events(apd) as events:
        apd.set_data(key, new_array)
        self.assertEqual(events, [{"changed": [key]}])

    # Deleting the key is reported as "removed".
    with self.monitor_events(apd) as events:
        apd.del_data(key)
        self.assertEqual(events, [{"removed": [key]}])
class TwoDimensionalPlot(ChacoPlot):
    """
    A 2D plot.

    Holds a single ('x', 'y') line whose data can be read and written through
    the ``x_data`` / ``y_data`` properties.
    """

    # Shared, class-level color rotation used when no color is given.
    auto_color_idx = 0
    auto_color_list = ['green', 'brown', 'blue', 'red', 'black']

    @classmethod
    def auto_color(cls):
        """
        Choose the next color, cycling through ``auto_color_list``.
        """
        color = cls.auto_color_list[cls.auto_color_idx]
        cls.auto_color_idx = (cls.auto_color_idx + 1) % len(cls.auto_color_list)
        return color

    def __init__(self, parent, color=None, *args, **kwargs):
        self.parent = parent

        if color is None:
            color = self.auto_color()

        self.data = ArrayPlotData()
        self.data.set_data('x', [0])
        self.data.set_data('y', [0])

        ChacoPlot.__init__(self, self.data, *args, **kwargs)

        self.plot(('x', 'y'), color=color)

        self.configure()

    @property
    def control(self):
        """
        A drawable control.
        """
        return Window(self.parent, component=self).control

    def get_data(self, axis):
        """
        Values for an axis ('x' or 'y').
        """
        return self.data.get_data(axis)

    def set_data(self, values, axis):
        """
        Replace the values for an axis ('x' or 'y').
        """
        self.data.set_data(axis, values)

    x_data = property(partial(get_data, axis='x'), partial(set_data, axis='x'))
    y_data = property(partial(get_data, axis='y'), partial(set_data, axis='y'))

    def _first_renderer(self):
        """
        Return the first renderer of the first (and only) plot.
        """
        # dict views are not subscriptable on Python 3, so iterate instead of
        # indexing ``self.plots.values()[0]``.
        return next(iter(self.plots.values()))[0]

    def x_autoscale(self):
        """
        Enable autoscaling for the x axis.
        """
        x_range = self._first_renderer().index_mapper.range
        x_range.low = x_range.high = 'auto'

    def y_autoscale(self):
        """
        Enable autoscaling for the y axis.
        """
        y_range = self._first_renderer().value_mapper.range
        y_range.low = y_range.high = 'auto'
class OpenMMScriptRunner(HasTraits):
    """Run an OpenMM script in a worker thread and plot its reports live."""
    plots = Instance(VPlotContainer)
    plots_created = Bool
    openmm_script_code = String
    status = String

    traits_view = View(
        Group(
            HGroup(spring, Item('status', style='readonly'), spring),
            Item('plots', editor=ComponentEditor(), show_label=False)
        ),
        width=800, height=600, resizable=True,
        title='OpenMM Script Runner'
    )

    def __init__(self, **traits):
        super(OpenMMScriptRunner, self).__init__(**traits)
        # Reset the trait that update_plot() actually checks (an underscored
        # `_plots_created` attribute is never read anywhere).
        self.plots_created = False
        q = Queue.Queue()

        # start up two threads. the first, t1, will run the script
        # and place the statedata into the queue
        # the second will remove elements from the queue and update the
        # plots in the UI
        t1 = threading.Thread(target=run_openmm_script,
                              args=(self.openmm_script_code, q))
        t2 = threading.Thread(target=self.queue_consumer, args=(q,))
        t1.start()
        t2.start()

    def queue_consumer(self, q):
        """Main loop for a thread that consumes the messages from the queue
        and plots them.

        A ``None`` message ends the loop.
        """
        # NOTE(review): update_plot() runs on this worker thread, not the GUI
        # thread -- confirm Traits/Chaco tolerate that here.
        self.status = 'Running...'
        while True:
            try:
                msg = q.get_nowait()
                if msg is None:
                    break
                self.update_plot(msg)
            except Queue.Empty:
                time.sleep(0.1)
        self.status = 'Done'

    def create_plots(self, keys):
        """Create the plots.

        Parameters
        ----------
        keys : list of strings
            A list of all of the keys in the msg dict. This should be something
            like ['Step', 'Temperature', 'Potential Energy']. We'll create the
            ArrayPlotData container in which each of these timeseries will
            get put.

        Raises
        ------
        ValueError
            If neither 'Step' nor 'Time (ps)' is among *keys*.
        """
        self.plots = VPlotContainer(resizable="hv", bgcolor="lightgray",
                                    fill_padding=True, padding=10)
        # Equivalent to ArrayPlotData(a=[], b=[], c=[]) for keys a, b, c.
        # Each key gets its own list object.
        self.plotdata = ArrayPlotData(**{key: [] for key in keys})

        # figure out which key will be the x axis
        if 'Step' in keys:
            x = 'Step'
        elif 'Time (ps)' in keys:
            x = 'Time (ps)'
        else:
            raise ValueError("The reporter published neither the step nor time "
                             "count, so I don't know what to plot on the x-axis!")

        colors = itertools.cycle(['blue', 'green', 'silver', 'pink',
                                  'lightblue', 'red', 'darkgray', 'lightgreen'])
        for y in [key for key in keys if key != x]:
            self.plots.add(chaco_scatter(self.plotdata, x_name=x, y_name=y,
                                         color=next(colors)))

    def update_plot(self, msg):
        """Add data points from the message to the plots.

        Parameters
        ----------
        msg : dict
            This is the message sent over the Queue from the script.
        """
        if not self.plots_created:
            self.create_plots(msg.keys())
            self.plots_created = True

        for k, v in msg.items():
            current = self.plotdata.get_data(k)
            self.plotdata.set_data(k, np.r_[current, v])
class OneDViewer(BaseViewer):
    """
    This class just contains the two data arrays that will be updated
    by the Controller.
    The visualization/editor for this class is a Chaco plot.
    """
    # NOTE(review): `data` and `max_size` are used below but not declared
    # here; presumably they come from BaseViewer -- confirm.

    ndim = Int(1)
    num_ticks = Int(0)
    resolution = Float(1.)
    start_pos = Float(0.)
    # Position(s) of the highlighted point drawn as a red dot.
    csr_pos = Array
    pd = Instance(ArrayPlotData, transient=True)
    plot = Any()

    traits_view = View(
        Item('plot', editor=ComponentEditor(), show_label=False),
        resizable=True,
    )

    def __init__(self, *args, **kargs):
        super(OneDViewer, self).__init__(*args, **kargs)
        self.create_plot_element()

    def set_data(self, new_data, idx=None):
        """Replace the whole trace, or a single element of it.

        Parameters
        ----------
        new_data : array or value
            The new trace when *idx* is None; otherwise the value stored at
            ``self.data[idx]``.
        idx : index, optional
            Position to update. It also becomes the cursor position.
        """
        # Compare against None explicitly: `if idx:` treated idx=0 as "no
        # index" and replaced the whole trace with new_data (and raised for
        # multi-element numpy index arrays).
        if idx is not None:
            self.data[idx] = new_data
            self.csr_pos = idx
        else:
            self.data = new_data
            self.csr_pos = np.array([0])
        self.refresh()

    def _data_default(self):
        return np.full((self.max_size,), np.nan)

    def _csr_pos_default(self):
        return np.array([0])

    def create_plot_element(self):
        """(Re)build the plot data and the Chaco plot for the current trace."""
        self.pd = ArrayPlotData(x=np.arange(self.data.size),
                                y=self.data,
                                posx=self.csr_pos,
                                posy=np.array([self.data[self.csr_pos]]))
        plot = Plot(self.pd)
        plot.plot(("x", "y"),
                  title='',
                  y_label="Signal",
                  color=tuple(cbrewer[np.random.randint(0, 10)]),
                  bgcolor="grey",
                  border_visible=True,
                  border_width=1,
                  width=800,
                  height=380,
                  marker_size=2,
                  show_label=False)
        plot.plot_1d("posx",
                     type="line_scatter_1d",
                     name="dot",
                     color="red",
                     marker="circle",
                     marker_size=4)
        # The ArrayPlotData above already holds x/y/posx/posy; re-setting them
        # with identical values afterwards was redundant.
        self.plot = plot

    def refresh(self):
        """Push current data to the plot, rebuilding it if the length changed."""
        if self.data.size == self.pd.get_data('y').size:
            self.pd.set_data('x', np.arange(self.data.size))
            self.pd.set_data('y', self.data)
            self.pd.set_data('posx', self.csr_pos)
            self.pd.set_data('posy', np.array([self.data[self.csr_pos]]))
        else:
            self.create_plot_element()

    @property_depends_on('data[]')
    def _get_positions(self):
        # NOTE(review): no `positions` Property is declared in this class;
        # presumably BaseViewer declares it -- confirm.
        if self.data is not None:
            return np.arange(self.data.size)
        else:
            return np.array([0, 1])
class ColormappedPlot(ChacoPlot):
    """
    A colormapped plot.

    Shows the 'color' array as an image using the ``jet`` colormap.
    """

    def __init__(self, parent, x_bounds, y_bounds, *args, **kwargs):
        self.parent = parent

        self.data = ArrayPlotData()
        self.data.set_data('color', [[0]])

        ChacoPlot.__init__(self, self.data, *args, **kwargs)

        self.img_plot('color', colormap=jet, xbounds=x_bounds, ybounds=y_bounds)

        self.configure()

    @property
    def plot_obj(self):
        """
        The actual plot object (first renderer of the only plot).
        """
        # dict views are not subscriptable on Python 3, so iterate instead of
        # indexing ``self.plots.values()[0]``.
        return next(iter(self.plots.values()))[0]

    @property
    def control(self):
        """
        A drawable control with a color bar.

        The color bar carries a RangeSelection tool whose listeners include
        the plot object.
        """
        color_map = self.plot_obj.color_mapper
        linear_mapper = LinearMapper(range=color_map.range)
        color_bar = ColorBar(index_mapper=linear_mapper, color_mapper=color_map,
                             plot=self.plot_obj, orientation='v',
                             resizable='v', width=30)
        color_bar._axis.tick_label_formatter = self.sci_formatter
        color_bar.padding_top = self.padding_top
        color_bar.padding_bottom = self.padding_bottom
        color_bar.padding_left = 50  # Room for labels.
        color_bar.padding_right = 10

        range_selection = RangeSelection(component=color_bar)
        range_selection.listeners.append(self.plot_obj)
        color_bar.tools.append(range_selection)

        range_selection_overlay = RangeSelectionOverlay(component=color_bar)
        color_bar.overlays.append(range_selection_overlay)

        container = HPlotContainer(use_backbuffer=True)
        container.add(self)
        container.add(color_bar)

        return Window(self.parent, component=container).control

    @property
    def color_data(self):
        """
        Plotted values.
        """
        return self.data.get_data('color')

    @color_data.setter
    def color_data(self, values):
        self.data.set_data('color', values)

    @property
    def low_setting(self):
        """
        Lowest color value.

        Reads the range's current ``low``; the setter writes ``low_setting``.
        """
        return self.plot_obj.color_mapper.range.low

    @low_setting.setter
    def low_setting(self, value):
        self.plot_obj.color_mapper.range.low_setting = value

    @property
    def high_setting(self):
        """
        Highest color value.

        Reads the range's current ``high``; the setter writes ``high_setting``.
        """
        return self.plot_obj.color_mapper.range.high

    @high_setting.setter
    def high_setting(self, value):
        self.plot_obj.color_mapper.range.high_setting = value
class Plot2D(DataView):
    """A DataView that renders 2D image and contour plots from named data.

    Data arrays are looked up by name in ``self.data`` and wrapped in
    datasources, which are cached in ``self.datasources``.
    """

    #------------------------------------------------------------------------
    # Data-related traits
    #------------------------------------------------------------------------

    # The PlotData instance that drives this plot.
    data = Instance(AbstractPlotData)

    # Mapping of data names from self.data to their respective datasources.
    datasources = Dict(Str, Instance(AbstractDataSource))

    #------------------------------------------------------------------------
    # General plotting traits
    #------------------------------------------------------------------------

    # Mapping of plot names to *lists* of plot renderers.
    plots = Dict(Str, List)

    index2d = Instance(GridDataSource)

    # Optional mapper for the color axis.  Not instantiated until first use;
    # destroyed if no color plots are on the plot.
    color_mapper = Instance(AbstractColormap)

    # Mapping of renderer type string to renderer class
    # This can be overriden to customize what renderer type the Plot
    # will instantiate for its various plotting methods.
    renderer_map = Dict(dict(img_plot=ImagePlot,
                             cmap_img_plot=CMapImagePlot,
                             contour_line_plot=ContourLinePlot,
                             contour_poly_plot=ContourPolyPlot,
                             ))

    #------------------------------------------------------------------------
    # Annotations and decorations
    #------------------------------------------------------------------------

    # The legend on the plot.
    legend = Instance(Legend)

    # Convenience attribute for legend.align; can be "ur", "ul", "ll", "lr".
    legend_alignment = Property

    #------------------------------------------------------------------------
    # Public methods
    #------------------------------------------------------------------------

    def __init__(self, data=None, **kwtraits):
        """Create the plot, optionally wrapping *data* in an ArrayPlotData.

        Raises
        ------
        ValueError
            If *data* is neither an AbstractPlotData nor an ndarray/tuple/list.
        """
        if 'origin' in kwtraits:
            self.default_origin = kwtraits.pop('origin')
        if 'bgcolor' not in kwtraits:
            kwtraits['bgcolor'] = 'black'
        super(Plot2D, self).__init__(**kwtraits)
        if data is not None:
            if isinstance(data, AbstractPlotData):
                self.data = data
            elif type(data) in (ndarray, tuple, list):
                self.data = ArrayPlotData(data)
            else:
                # Call syntax works on Python 2 and 3; the old message was
                # also missing a space ("datafor").
                raise ValueError("Don't know how to create PlotData for data "
                                 "of type " + str(type(data)))

        if not self.legend:
            self.legend = Legend(visible=False, align="ur", error_icon="blank",
                                 padding=10, component=self)

        # ensure that we only get displayed once by new_window()
        self._plot_ui_info = None

    def img_plot(self, data, name=None, colormap=None,
                 xbounds=None, ybounds=None, origin=None, hide_grids=True,
                 **styles):
        """ Adds image plots to this Plot object.

        If *data* has shape (N, M, 3) or (N, M, 4), then it is treated as RGB or
        RGBA (respectively) and *colormap* is ignored.

        If *data* is an array of floating-point data, then a colormap can
        be provided via the *colormap* argument, or the default of 'Spectral'
        will be used.

        *Data* should be in row-major order, so that xbounds corresponds to
        *data*'s second axis, and ybounds corresponds to the first axis.

        Parameters
        ==========
        data : string
            The name of the data array in self.data
        name : string
            The name of the plot; if omitted, then a name is generated.
        xbounds, ybounds : string, tuple, or ndarray
            Bounds where this image resides. Bound may be: a) names of
            data in the plot data; b) tuples of (low, high) in data space,
            c) 1D arrays of values representing the pixel boundaries (must
            be 1 element larger than underlying data), or
            d) 2D arrays as obtained from a meshgrid operation
        origin : string
            Which corner the origin of this plot should occupy:
                "bottom left", "top left", "bottom right", "top right"
        hide_grids : bool, default True
            Whether or not to automatically hide the grid lines on the plot
        styles : series of keyword arguments
            Attributes and values that apply to one or more of the
            plot types requested, e.g.,'line_color' or 'line_width'.
        """
        if name is None:
            name = self._make_new_plot_name()
        if origin is None:
            origin = self.default_origin

        value = self._get_or_create_datasource(data)
        array_data = value.get_data()
        if len(array_data.shape) == 3:
            if array_data.shape[2] not in (3, 4):
                raise ValueError("Image plots require color depth of 3 or 4.")
            cls = self.renderer_map["img_plot"]
            kwargs = dict(**styles)
        else:
            if colormap is None:
                if self.color_mapper is None:
                    colormap = Spectral(DataRange1D(value))
                else:
                    colormap = self.color_mapper
            elif isinstance(colormap, AbstractColormap):
                if colormap.range is None:
                    colormap.range = DataRange1D(value)
            else:
                # A colormap factory: call it with a range over the data.
                colormap = colormap(DataRange1D(value))
            self.color_mapper = colormap
            cls = self.renderer_map["cmap_img_plot"]
            kwargs = dict(value_mapper=colormap, **styles)
        return self._create_2d_plot(cls, name, origin, xbounds, ybounds, value,
                                    hide_grids, **kwargs)

    def contour_plot(self, data, type="line", name=None, poly_cmap=None,
                     xbounds=None, ybounds=None, origin=None, hide_grids=True,
                     **styles):
        """ Adds contour plots to this Plot object.

        Parameters
        ==========
        data : string
            The name of the data array in self.data, which must be
            floating point data.
        type : comma-delimited string of "line", "poly"
            The type of contour plot to add. If the value is "poly"
            and no colormap is provided via the *poly_cmap* argument, then
            a default colormap of 'Spectral' is used.
        name : string
            The name of the plot; if omitted, then a name is generated.
        poly_cmap : string
            The name of the color-map function to call (in
            chaco.default_colormaps) or an AbstractColormap instance
            to use for contour poly plots (ignored for contour line plots)
        xbounds, ybounds : string, tuple, or ndarray
            Bounds where this image resides. Bound may be: a) names of
            data in the plot data; b) tuples of (low, high) in data space,
            c) 1D arrays of values representing the pixel boundaries (must
            be 1 element larger than underlying data), or
            d) 2D arrays as obtained from a meshgrid operation
        origin : string
            Which corner the origin of this plot should occupy:
                "bottom left", "top left", "bottom right", "top right"
        hide_grids : bool, default True
            Whether or not to automatically hide the grid lines on the plot
        styles : series of keyword arguments
            Attributes and values that apply to one or more of the
            plot types requested, e.g.,'line_color' or 'line_width'.
        """
        if name is None:
            name = self._make_new_plot_name()
        if origin is None:
            origin = self.default_origin

        value = self._get_or_create_datasource(data)
        if value.value_depth != 1:
            raise ValueError("Contour plots require 2D scalar field")
        if type == "line":
            cls = self.renderer_map["contour_line_plot"]
            kwargs = dict(**styles)
            # if colors is given as a factory func, use it to make a
            # concrete colormapper. Better way to do this?
            if "colors" in kwargs:
                cmap = kwargs["colors"]
                if isinstance(cmap, FunctionType):
                    kwargs["colors"] = cmap(DataRange1D(value))
                elif getattr(cmap, 'range', 'dummy') is None:
                    cmap.range = DataRange1D(value)
        elif type == "poly":
            if poly_cmap is None:
                poly_cmap = Spectral(DataRange1D(value))
            elif isinstance(poly_cmap, FunctionType):
                poly_cmap = poly_cmap(DataRange1D(value))
            elif getattr(poly_cmap, 'range', 'dummy') is None:
                poly_cmap.range = DataRange1D(value)
            cls = self.renderer_map["contour_poly_plot"]
            kwargs = dict(color_mapper=poly_cmap, **styles)
        else:
            raise ValueError("Unhandled contour plot type: " + type)

        return self._create_2d_plot(cls, name, origin, xbounds, ybounds, value,
                                    hide_grids, **kwargs)

    def _process_2d_bounds(self, bounds, array_data, axis):
        """Transform an arbitrary bounds definition into a linspace.

        Process all the ways the user could have defined the x- or y-bounds
        of a 2d plot and return a linspace between the lower and upper
        range of the bounds.

        Parameters
        ----------
        bounds : any
            User bounds definition

        array_data : 2D array
            The 2D plot data

        axis : int
            The axis along which the bounds are to be set

        Raises
        ------
        ValueError
            If *bounds* has the wrong size/shape or an unsupported type.
        """
        num_ticks = array_data.shape[axis] + 1

        if bounds is None:
            return arange(num_ticks)

        if type(bounds) is tuple:
            # create a linspace with the bounds limits
            return linspace(bounds[0], bounds[1], num_ticks)

        if type(bounds) is ndarray and len(bounds.shape) == 1:
            # bounds is 1D, but of the wrong size
            if len(bounds) != num_ticks:
                msg = ("1D bounds of an image plot needs to have 1 more "
                       "element than its corresponding data shape, because "
                       "they represent the locations of pixel boundaries.")
                raise ValueError(msg)
            else:
                return linspace(bounds[0], bounds[-1], num_ticks)

        if type(bounds) is ndarray and len(bounds.shape) == 2:
            # bounds is 2D, assumed to be a meshgrid
            # This is triggered when doing something like
            # >>> xbounds, ybounds = meshgrid(...)
            # >>> z = f(xbounds, ybounds)
            if bounds.shape != array_data.shape:
                msg = ("2D bounds of an image plot needs to have the same "
                       "shape as the underlying data, because "
                       "they are assumed to be generated from meshgrids.")
                raise ValueError(msg)
            else:
                if axis == 0:
                    bounds = bounds[:, 0]
                else:
                    bounds = bounds[0, :]
                # Extend by one step: the meshgrid gives pixel positions,
                # but num_ticks counts pixel boundaries.
                interval = bounds[1] - bounds[0]
                return linspace(bounds[0], bounds[-1] + interval, num_ticks)

        raise ValueError("bounds must be None, a tuple, an array, "
                         "or a PlotData name")

    def _create_2d_plot(self, cls, name, origin, xbounds, ybounds, value_ds,
                        hide_grids, **kwargs):
        """Instantiate renderer *cls* over a grid index and register it."""
        if name is None:
            name = self._make_new_plot_name()
        if origin is None:
            origin = self.default_origin

        array_data = value_ds.get_data()

        # process bounds to get linspaces
        if isinstance(xbounds, basestring):
            xbounds = self._get_or_create_datasource(xbounds).get_data()

        xs = self._process_2d_bounds(xbounds, array_data, 1)

        if isinstance(ybounds, basestring):
            ybounds = self._get_or_create_datasource(ybounds).get_data()

        ys = self._process_2d_bounds(ybounds, array_data, 0)

        # Create the index and add its datasources to the appropriate ranges
        # NOTE(review): the class declares an `index2d` trait, but the index
        # is stored on `self.index` -- confirm which one callers rely on.
        self.index = GridDataSource(xs, ys, sort_order=('ascending', 'ascending'))
        self.range2d.add(self.index)
        mapper = GridMapper(range=self.range2d,
                            stretch_data_x=self.x_mapper.stretch_data,
                            stretch_data_y=self.y_mapper.stretch_data)

        plot = cls(index=self.index,
                   value=value_ds,
                   index_mapper=mapper,
                   orientation=self.orientation,
                   origin=origin,
                   **kwargs)

        if hide_grids:
            self.x_grid.visible = False
            self.y_grid.visible = False

        self.add(plot)
        self.plots[name] = [plot]
        return self.plots[name]

    def delplot(self, *names):
        """ Removes the named sub-plots. """

        # This process involves removing the plots, then checking the index range
        # and value range for leftover datasources, and removing those if necessary.

        # Remove all the renderers from us (container) and create a set of the
        # datasources that we might have to remove from the ranges
        deleted_sources = set()
        for renderer in itertools.chain(*[self.plots.pop(name) for name in names]):
            self.remove(renderer)
            deleted_sources.add(renderer.index)
            deleted_sources.add(renderer.value)

        # Cull the candidate list of sources to remove by checking the other plots
        sources_in_use = set()
        for p in itertools.chain(*self.plots.values()):
            sources_in_use.add(p.index)
            sources_in_use.add(p.value)

        unused_sources = deleted_sources - sources_in_use - set([None])

        # Remove the unused sources from all ranges
        for source in unused_sources:
            if source.index_dimension == "scalar":
                # Try both index and range, it doesn't hurt
                self.index_range.remove(source)
                self.value_range.remove(source)
            elif source.index_dimension == "image":
                self.range2d.remove(source)
            else:
                warnings.warn("Couldn't remove datasource from datarange.")

    def hideplot(self, *names):
        """ Convenience function to sets the named plots to be invisible.  Their
        renderers are not removed, and they are still in the list of plots.
        """
        for renderer in itertools.chain(*[self.plots[name] for name in names]):
            renderer.visible = False

    def showplot(self, *names):
        """ Convenience function to sets the named plots to be visible.
        """
        for renderer in itertools.chain(*[self.plots[name] for name in names]):
            renderer.visible = True

    #------------------------------------------------------------------------
    # Private methods
    #------------------------------------------------------------------------

    def _make_new_plot_name(self):
        """ Returns a string that is not already used as a plot title.
        """
        n = len(self.plots)
        plot_template = "plot%d"
        while True:
            name = plot_template % n
            if name not in self.plots:
                return name
            n += 1

    def _get_or_create_datasource(self, name, sort_order='none'):
        """ Returns the data source associated with the given name, or creates
        it if it doesn't exist.

        Raises
        ------
        ValueError
            If the named data has an unsupported type or array shape.
        """

        if name not in self.datasources:
            data = self.data.get_data(name)

            if type(data) in (list, tuple):
                data = array(data)

            if isinstance(data, ndarray):
                ndim = len(data.shape)
                if ndim == 1:
                    ds = ArrayDataSource(data, sort_order=sort_order)
                elif ndim == 2:
                    ds = ImageData(data=data, value_depth=1)
                elif ndim == 3 and data.shape[2] in (3, 4):
                    ds = ImageData(data=data, value_depth=int(data.shape[2]))
                else:
                    # Previously 0-D and >3-D arrays fell through and hit an
                    # UnboundLocalError on `ds`; report them explicitly.
                    raise ValueError("Unhandled array shape in creating new plot: "
                                     + str(data.shape))

            elif isinstance(data, AbstractDataSource):
                ds = data

            else:
                raise ValueError("Couldn't create datasource for data of type "
                                 + str(type(data)))

            self.datasources[name] = ds

        return self.datasources[name]

    #------------------------------------------------------------------------
    # Event handlers
    #------------------------------------------------------------------------

    def _color_mapper_changed(self):
        for plist in self.plots.values():
            for plot in plist:
                plot.color_mapper = self.color_mapper
        self.invalidate_draw()

    def _data_changed(self, old, new):
        if old:
            old.on_trait_change(self._data_update_handler, "data_changed",
                                remove=True)
        if new:
            new.on_trait_change(self._data_update_handler, "data_changed")

    def _data_update_handler(self, name, event):
        # event should be a dict with keys "added", "removed", and "changed",
        # per the comments in AbstractPlotData.  Only "changed" requires work:
        # cached datasources are refreshed from self.data.
        for data_name in event.get("changed", ()):
            if data_name in self.datasources:
                source = self.datasources[data_name]
                source.set_data(self.data.get_data(data_name))

    def _plots_items_changed(self, event):
        if self.legend:
            self.legend.plots = self.plots

    def _legend_changed(self, old, new):
        self._overlay_change_helper(old, new)
        if new:
            new.plots = self.plots

    def _handle_range_changed(self, name, old, new):
        """ Overrides the DataView default behavior.

        Primarily changes how the list of renderers is looked up.
        """
        mapper = getattr(self, name + "_mapper")
        if mapper.range == old:
            mapper.range = new
        if old is not None:
            for datasource in old.sources[:]:
                old.remove(datasource)
                if new is not None:
                    new.add(datasource)
        range_name = name + "_range"
        for renderer in itertools.chain(*self.plots.values()):
            if hasattr(renderer, range_name):
                setattr(renderer, range_name, new)

    #------------------------------------------------------------------------
    # Property getters and setters
    #------------------------------------------------------------------------

    def _set_legend_alignment(self, align):
        if self.legend:
            self.legend.align = align

    def _get_legend_alignment(self):
        if self.legend:
            return self.legend.align
        else:
            return None
class CellCropController(BaseImageController):
    """Pick a template region from an image, find matching peaks in every
    image by cross correlation, and crop the matched cells into the chest.

    Workflow, as visible in this class: ``init_plot`` cuts an initial
    template from the active image; ``locate_peaks`` cross-correlates the
    template with every image and stores peak lists in ``self.peaks``;
    ``crop_cells`` writes a table row and a cropped array per peak whose
    third column lies within ``self.thresh``.
    """
    # Lower bound for the template position ranges below.
    zero=Int(0)
    template_plot = Instance(BasePlotContainer)
    template_data = Instance(ArrayPlotData)
    # Template side length in pixels (templates are square).
    template_size = Range(low=2, high=512, value=64, cols=4)
    # Template top/left corner; upper bounds track max_pos_y / max_pos_x.
    template_top = Range(low='zero',high='max_pos_y', value=20, cols=4)
    template_left = Range(low='zero',high='max_pos_x', value=20, cols=4)
    # Maps image name -> peak array. Column 0 is used as x/"index",
    # column 1 as y/"value", column 2 as the correlation height.
    peaks = Dict({})
    # When True, the main plot shows the cross correlation instead of the image.
    ShowCC = Bool(False)
    max_pos_x = Int(256)
    max_pos_y = Int(256)
    is_square = Bool(True)
    peak_width = Range(low=2, high=200, value=10)
    numpeaks_total = Int(0,cols=5)
    numpeaks_img = Int(0,cols=5)
    _session_id = String('')

    def __init__(self, parent, treasure_chest=None, data_path='/rawdata',
                 *args, **kw):
        super(CellCropController, self).__init__(parent, treasure_chest,
                                                 data_path, *args, **kw)
        # NOTE(review): `self.nodes` is not defined in this class; presumably
        # the base controller provides it -- confirm.
        if self.chest is not None:
            self.numfiles = len(self.nodes)
            if self.numfiles > 0:
                self.init_plot()
                print "initialized plot for data in %s" % data_path

    def data_updated(self):
        """Reinitialize the controller after the underlying data changed."""
        # reinitialize data
        self.__init__(parent = self.parent, treasure_chest=self.chest,
                      data_path=self.data_path)

    def init_plot(self):
        """Build the main scatter-overlay plot and the initial template plot."""
        self.plotdata.set_data('imagedata', self.get_active_image())
        self.plot = self.get_scatter_overlay_plot(array_plot_data=self.plotdata,
                                                  title=self.get_active_name(),
                                                  tools=['csr','colorbar','zoom','pan']
                                                  )
        # pick an initial template with default parameters
        self.template_data = ArrayPlotData()
        self.template_plot = Plot(self.template_data, default_origin="top left")
        # Crop a template_size x template_size window at (top, left).
        self.template_data.set_data('imagedata',
                                    self.get_active_image()[
                                        self.template_top:self.template_top + self.template_size,
                                        self.template_left:self.template_left + self.template_size
                                    ]
                                    )
        self.template_plot.img_plot('imagedata', title = "Template")
        self.template_plot.aspect_ratio=1 #square templates
        self.template_filename = self.get_active_name()
        self._get_max_positions()

    @on_trait_change("selected_index, ShowCC")
    def update_image(self):
        """Redraw the main plot with either the image or its cross correlation
        with the template, and resize the grid data source to match."""
        if self.ShowCC:
            CC = cv_funcs.xcorr(self.template_data.get_data('imagedata'),
                                self.get_active_image())
            self.plotdata.set_data("imagedata",CC)
            self.plot = self.get_scatter_overlay_plot(array_plot_data=self.plotdata,
                                                      title=self.get_active_name(),
                                                      tools=['csr','zoom','pan', 'colorbar'],
                                                      )
            self.plot.aspect_ratio = (float(CC.shape[1])/
                                      CC.shape[0])
            self.set_plot_title("Cross correlation of " + self.get_active_name())
            grid_data_source = self._base_plot.range2d.sources[0]
            grid_data_source.set_data(np.arange(CC.shape[1]),
                                      np.arange(CC.shape[0]))
        else:
            # NOTE(review): get_active_image() is called several times here;
            # each call re-reads the node via the base controller.
            self.plotdata.set_data("imagedata", self.get_active_image())
            self.plot = self.get_scatter_overlay_plot(array_plot_data=self.plotdata,
                                                      title=self.get_active_name(),
                                                      tools=['csr','zoom','pan', 'colorbar'],
                                                      )
            self.plot.aspect_ratio = (float(self.get_active_image().shape[1])/
                                      self.get_active_image().shape[0])
            self.set_plot_title(self.get_active_name())
            grid_data_source = self._base_plot.range2d.sources[0]
            grid_data_source.set_data(np.arange(self.get_active_image().shape[1]),
                                      np.arange(self.get_active_image().shape[0]))

    def update_CC(self):
        """Recompute the cross-correlation image if it is being shown."""
        if self.ShowCC:
            CC = cv_funcs.xcorr(self.template_data.get_data('imagedata'),
                                self.get_active_image())
            self.plotdata.set_data("imagedata",CC)

    @on_trait_change('template_left, template_top, template_size')
    def update_template_data(self):
        """Re-crop the template from the active image after it moved/resized.

        Any previously located peaks are discarded, since they were found with
        the old template.
        """
        self.template_data.set_data('imagedata',
                                    self.get_active_image()[
                                        self.template_top:self.template_top + self.template_size,
                                        self.template_left:self.template_left + self.template_size
                                    ]
                                    )
        self.template_filename = self.get_active_name()
        if self.numpeaks_total>0:
            print "clearing peaks"
            self.peaks={}
        # when template data changes, we should check whether to update the
        # cross correlation plot, which depends on the template
        self.update_CC()

    @on_trait_change('selected_index, template_size')
    def _get_max_positions(self):
        """Update the largest allowed template corner for the active image.

        The maxima are only updated when the computed value is positive.
        """
        max_pos_x=self.get_active_image().shape[-1]-self.template_size-1
        if max_pos_x>0:
            self.max_pos_x = int(max_pos_x)
        max_pos_y=self.get_active_image().shape[-2]-self.template_size-1
        if max_pos_y>0:
            self.max_pos_y = int(max_pos_y)

    @on_trait_change('template_left, template_top')
    def update_csr_position(self):
        """Move the cursor tool to the template's top-left corner."""
        # NOTE(review): the cursor is not moved when template_left == 0.
        if self.template_left>0:
            self._csr.current_position=self.template_left,self.template_top
        pass

    @on_trait_change('_csr:current_position')
    def update_top_left(self):
        """Move the template to the cursor, clamping at the allowed maxima.

        Cursor position is (x, y): x drives template_left, y drives
        template_top.
        """
        if self._csr.current_position[0]>0 or self._csr.current_position[1]>0:
            if self._csr.current_position[0]>self.max_pos_x:
                # NOTE(review): in this branch template_left is left unchanged
                # rather than clamped to max_pos_x -- confirm intended.
                if self._csr.current_position[1]<self.max_pos_y:
                    self.template_top=self._csr.current_position[1]
                else:
                    self._csr.current_position=self.max_pos_x, self.max_pos_y
            elif self._csr.current_position[1]>self.max_pos_y:
                self.template_left,self.template_top=self._csr.current_position[0],self.max_pos_y
            else:
                self.template_left,self.template_top=self._csr.current_position

    @on_trait_change('_colorbar_selection:selection')
    def update_thresh(self):
        """Copy the colorbar range selection into the threshold and redraw."""
        # NOTE(review): the bare except silently ignores every error here,
        # e.g. when no colorbar selection or scatter plot exists yet.
        try:
            thresh=self._colorbar_selection.selection
            self.thresh=thresh
            scatter_renderer=self._scatter_plot.plots['scatter_plot'][0]
            scatter_renderer.color_data.metadata['selections']=thresh
            self.thresh_lower=thresh[0]
            self.thresh_upper=thresh[1]
            scatter_renderer.color_data.metadata_changed={'selections':thresh}
            self.plot.request_redraw()
        except:
            pass

    @on_trait_change('thresh_upper,thresh_lower')
    def manual_thresh_update(self):
        """Apply manually entered threshold bounds to the scatter selection."""
        self.thresh=[self.thresh_lower,self.thresh_upper]
        scatter_renderer=self._scatter_plot.plots['scatter_plot'][0]
        scatter_renderer.color_data.metadata['selections']=self.thresh
        scatter_renderer.color_data.metadata_changed={'selections':self.thresh}
        self.plot.request_redraw()

    @on_trait_change('peaks, _colorbar_selection:selection, selected_index')
    def calc_numpeaks(self):
        """Count peaks whose third column lies inside the threshold.

        ``numpeaks_total`` counts over all images and ``numpeaks_img`` over the
        active image only. Falls back to the range (-1, 1) when there is no
        usable colorbar selection.
        """
        try:
            thresh=self._colorbar_selection.selection
            self.thresh=thresh
        except:
            thresh=[]
        if thresh==[] or thresh==() or thresh==None:
            thresh=(-1,1)
        # masked_inside masks values within [low, high]; summing the mask
        # therefore counts the peaks inside the threshold.
        self.numpeaks_total=int(np.sum([np.sum(np.ma.masked_inside(
            self.peaks[image_id][:,2], thresh[0], thresh[1]).mask)
                                         for image_id in self.peaks.keys()
                                         ]
                                        )
                                 )
        try:
            self.numpeaks_img=int(np.sum(np.ma.masked_inside(
                self.peaks[self.get_active_name()][:,2],
                thresh[0],thresh[1]).mask))
        except:
            self.numpeaks_img=0

    @on_trait_change('peaks, selected_index')
    def update_scatter_plot(self):
        """Overlay the active image's peaks, or clear the overlay if it has none."""
        data = self.plotdata.get_data('imagedata')
        # NOTE(review): aspect_ratio is computed but never used.
        aspect_ratio = (float(data.shape[1])/
                        data.shape[0])
        if self.get_active_name() in self.peaks:
            self.plotdata.set_data("index",self.peaks[self.get_active_name()][:,0])
            self.plotdata.set_data("value",self.peaks[self.get_active_name()][:,1])
            self.plotdata.set_data("color",self.peaks[self.get_active_name()][:,2])
            self.plot = self.get_scatter_overlay_plot(array_plot_data=self.plotdata,
                                                      tools=['zoom','pan','colorbar'])
            scatter_renderer = self._scatter_plot.plots['scatter_plot'][0]
            scatter_renderer.color_data.metadata['selections']=self.thresh
            scatter_renderer.color_data.metadata_changed={'selections':self.thresh}
        else:
            if 'index' in self.plotdata.arrays:
                self.plotdata.del_data('index')
                # 'value' is assumed to exist whenever 'index' exists.
                self.plotdata.del_data('value')
            if 'color' in self.plotdata.arrays:
                self.plotdata.del_data('color')
            self.plot = self.get_scatter_overlay_plot(array_plot_data=self.plotdata,
                                                      )

    def locate_peaks(self):
        """Cross-correlate the template with every image and store its peaks.

        Side effect: the active index is left on the last image.
        """
        peaks={}
        for idx in xrange(self.numfiles):
            self.set_active_index(idx)
            CC = cv_funcs.xcorr(self.template_data.get_data("imagedata"),
                                self.get_active_image())
            # Shift/scale the correlation to start at 0 with a *255 scale
            # before peak finding; the height column is mapped back below.
            pks=pc.two_dim_findpeaks((CC-CC.min())*255, medfilt_radius=None, alpha=1,
                                     coords_list=[],
                                     )
            pks=pc.flatten_peak_list(pks)
            # Undo the scaling so column 2 is in correlation units again.
            pks[:,2]=pks[:,2]/255+CC.min()
            peaks[self.get_active_name()]=pks
        self.peaks=peaks

    def mask_peaks(self,image_id):
        """Return the peaks of *image_id* with heights outside thresh masked."""
        mpeaks=np.ma.asarray(self.peaks[image_id])
        mpeaks[:,2]=np.ma.masked_outside(mpeaks[:,2],self.thresh[0],self.thresh[1])
        return mpeaks

    def crop_cells(self):
        """Crop every thresholded peak into the chest.

        Resets the cell description table and the '/cells' group, stores the
        template, writes one table row and one cropped array per kept peak,
        records the crop settings as table attributes, and stores an
        'average' image.
        """
        rows = self.chest.root.cell_description.nrows
        if rows > 0:
            # remove the table
            self.chest.removeNode('/cell_description')
            try:
                # remove the table of peak characteristics - they are not valid.
                self.chest.removeNode('/cell_peaks')
            except:
                pass
            # recreate it
            self.chest.createTable('/', 'cell_description', CellsTable)
            # remove all existing entries in the data group
            for node in self.chest.listNodes('/cells'):
                self.chest.removeNode('/cells/' + node.name)
        # store the template
        template_data = self.template_data.get_data('imagedata')
        self.parent.add_cell_data(template_data, name="template")
        # TODO: set attribute that tells where the template came from
        row = self.chest.root.cell_description.row
        files=[]
        for idx in xrange(self.numfiles):
            # filter the peaks that are outside the selected threshold
            self.set_active_index(idx)
            active_image = self.get_active_image()
            peaks=np.ma.compress_rows(self.mask_peaks(self.get_active_name()))
            files.append(self.get_active_name())
            tmp_sz=self.template_size
            data=np.zeros((peaks.shape[0],tmp_sz,tmp_sz),
                          dtype=active_image.dtype)
            if data.shape[0] >0:
                for i in xrange(peaks.shape[0]):
                    # store the peak in the table
                    # NOTE(review): 'file_idx' receives the peak index i, not
                    # the file index idx -- confirm intended.
                    row['file_idx'] = i
                    row['input_data'] = self.data_path
                    row['filename'] = self.get_active_name()
                    row['x_coordinate'] = peaks[i, 1]
                    row['y_coordinate'] = peaks[i, 0]
                    row.append()
                    # crop the cells from the given locations
                    # NOTE(review): peaks near the image edge give a slice
                    # smaller than tmp_sz, and float coordinates cannot be
                    # used as slice bounds in recent numpy -- confirm.
                    data[i,:,:]=active_image[peaks[i, 1]:peaks[i, 1] + tmp_sz,
                                             peaks[i, 0]:peaks[i, 0] + tmp_sz]
                self.chest.root.cell_description.flush()
                self.parent.add_cell_data(data, name=self.get_active_name())
                # insert the data (one 3d array per file)
        self.chest.setNodeAttr('/cell_description', 'threshold', (self.thresh_lower, self.thresh_upper))
        self.chest.setNodeAttr('/cell_description', 'template_position', (self.template_left, self.template_top))
        self.chest.setNodeAttr('/cell_description', 'template_filename', self.template_filename)
        self.chest.setNodeAttr('/cell_description', 'template_size', (self.template_size))
        self.chest.root.cell_description.flush()
        self.chest.flush()
        # NOTE(review): `data` holds only the last file's crops here, so this
        # "average" covers the last image only -- confirm intended.
        average_data = np.average(data,axis=0).squeeze()
        self.parent.add_cell_data(average_data, name="average")
        self.parent.update_cell_data()
self.log_action(action="crop cells", files=files, thresh=self.thresh, template_position=(self.template_left, self.template_top), template_size=self.template_size, template_filename=self.template_filename) Application.instance().end_session(self._session_id)
class Plot1D(DataView):
    """ Container for 1D line and scatter renderers driven by a plot-data object.

    Renderers are created with plot() and removed with delplot().  The
    renderer class for each plot type comes from ``renderer_map``.  Unless
    overridden, the background is black and the grid lines are yellow.
    """

    #------------------------------------------------------------------------
    # Data-related traits
    #------------------------------------------------------------------------

    # The PlotData instance that drives this plot.
    data = Instance(AbstractPlotData)

    # Mapping of data names from self.data to their respective datasources.
    datasources = Dict(Str, Instance(AbstractDataSource))

    #------------------------------------------------------------------------
    # General plotting traits
    #------------------------------------------------------------------------

    # Mapping of plot names to *lists* of plot renderers.
    plots = Dict(Str, List)

    # The default index to use when adding new subplots.
    default_index = Instance(AbstractDataSource)

    # Colors cycled through when plot(..., color="auto") is requested.
    # "white" comes first because the default background is black.
    auto_colors = List(["white", "red", "blue", "green", "lightblue",
                        "pink", "silver"])

    # index into auto_colors list
    _auto_color_idx = Int(-1)
    _auto_edge_color_idx = Int(-1)
    _auto_face_color_idx = Int(-1)

    # Mapping of renderer type string to renderer class.
    # This can be overridden to customize what renderer type the plot
    # instantiates for its various plotting methods.
    renderer_map = Dict(dict(line=LinePlot, scatter=ScatterPlot))

    #------------------------------------------------------------------------
    # Annotations and decorations
    #------------------------------------------------------------------------

    # The legend on the plot.
    legend = Instance(Legend)

    # Convenience attribute for legend.align; can be "ur", "ul", "ll", "lr".
    legend_alignment = Property

    def __init__(self, data=None, grid_color='yellow', **kwtraits):
        """ Create the plot.

        Parameters
        ----------
        data : AbstractPlotData, ndarray, tuple, list or None
            Plot data; arrays/sequences are wrapped in an ArrayPlotData.
        grid_color : color
            Line color of both the x and y grids.
        **kwtraits
            Trait values; "origin" is taken as ``default_origin`` and
            "bgcolor" defaults to black.

        Raises
        ------
        ValueError
            If *data* is of an unsupported type.
        """
        if 'origin' in kwtraits:
            self.default_origin = kwtraits.pop('origin')
        if 'bgcolor' not in kwtraits:
            kwtraits['bgcolor'] = 'black'
        super(Plot1D, self).__init__(**kwtraits)
        self.x_grid.line_color = grid_color
        self.y_grid.line_color = grid_color
        self.padding = (65, 10, 10, 50)
        if data is not None:
            if isinstance(data, AbstractPlotData):
                self.data = data
            elif isinstance(data, (ndarray, tuple, list)):
                self.data = ArrayPlotData(data)
            else:
                raise ValueError("Don't know how to create PlotData for data "
                                 "of type " + str(type(data)))

        if not self.legend:
            self.legend = Legend(visible=False, align="ur", error_icon="blank",
                                 padding=10, component=self)

        # ensure that we only get displayed once by new_window()
        self._plot_ui_info = None
        return

    def plot(self, data, type="line", name=None, index_scale="linear",
             value_scale="linear", origin=None, **styles):
        """ Adds a new sub-plot using the given data and plot style.

        Parameters
        ==========
        data : string, tuple(string), list(string)
            The data to be plotted. The number of items determines how they
            are interpreted:

            one item:
                The data is treated as the value and self.default_index is
                used as the index.  If **default_index** does not exist, one
                is created from arange(len(*data*)).
            two or more items:
                Interpreted as (index, value1, value2, ...).  Each
                index,value pair forms a new plot of the type specified.

        type : "line" or "scatter"
            The type of plot to add; the renderer class is looked up in
            ``renderer_map``.
        name : string
            The name of the plot.  If None, then a default one is created
            (usually "plotNNN").
        index_scale : string
            The type of scale to use for the index axis.  If not "linear",
            then a log scale is used.
        value_scale : string
            The type of scale to use for the value axis.  If not "linear",
            then a log scale is used.
        origin : string
            Which corner the origin of this plot should occupy:
                "bottom left", "top left", "bottom right", "top right"
        styles : series of keyword arguments
            attributes and values that apply to the renderers, e.g.,
            'line_color' or 'line_width'.  ``color="auto"`` gives each
            created series the next color from ``auto_colors``.

        Examples
        ========
        ::

            plot("my_data", type="line", name="myplot", color="lightblue")

            plot(("x-data", "y-data"), type="scatter")

            plot(("x", "y1", "y2", "y3"), color="auto")

        Returns
        =======
        [renderers] -> list of renderers created in response to this call
        to plot(), or None if *data* is empty.

        Raises
        ======
        ValueError
            If *type* is not "line" or "scatter".
        """
        if len(data) == 0:
            return
        if isinstance(data, basestring):
            data = (data,)

        self.index_scale = index_scale
        self.value_scale = value_scale

        # TODO: support lists of plot types
        plot_type = type
        if name is None:
            name = self._make_new_plot_name()
        if origin is None:
            origin = self.default_origin

        if plot_type not in ("line", "scatter"):
            raise ValueError("Unknown plot type: " + plot_type)

        if len(data) == 1:
            if self.default_index is None:
                # Create the default index based on the length of the first
                # data series
                value = self._get_or_create_datasource(data[0],
                                                       sort_order="ascending")
                self.default_index = ArrayDataSource(
                    arange(len(value.get_data())), sort_order="ascending")
                self.index_range.add(self.default_index)
            index = self.default_index
        else:
            index = self._get_or_create_datasource(data[0])
            if self.default_index is None:
                self.default_index = index
            self.index_range.add(index)
            data = data[1:]

        # Decide once whether auto-coloring was requested: the loop replaces
        # styles["color"] with a concrete color, so re-testing it per series
        # would give every series after the first the same color.
        auto_color = styles.get("color") == "auto"
        cls = self.renderer_map[plot_type]
        new_plots = []
        for value_name in data:
            value = self._get_or_create_datasource(value_name)
            self.value_range.add(value)

            if auto_color:
                self._auto_color_idx = \
                    (self._auto_color_idx + 1) % len(self.auto_colors)
                styles["color"] = self.auto_colors[self._auto_color_idx]

            imap = self._create_1d_mapper(self.index_scale, self.index_range,
                                          self.index_mapper.stretch_data)
            vmap = self._create_1d_mapper(self.value_scale, self.value_range,
                                          self.value_mapper.stretch_data)

            plot = cls(index=index,
                       value=value,
                       index_mapper=imap,
                       value_mapper=vmap,
                       orientation=self.orientation,
                       origin=origin,
                       **styles)

            self.add(plot)
            new_plots.append(plot)

        self.plots[name] = new_plots
        return self.plots[name]

    def delplot(self, *names):
        """ Removes the named sub-plots.

        Datasources that no remaining renderer uses are removed from the
        data ranges and from ``self.datasources``.

        Raises
        ------
        KeyError
            If one of *names* is not a plot name.
        """
        # Remove all the renderers from us (container) and collect the
        # datasources that we might have to remove from the ranges
        deleted_sources = set()
        for renderer in itertools.chain(*[self.plots.pop(name) for name in names]):
            self.remove(renderer)
            deleted_sources.add(renderer.index)
            deleted_sources.add(renderer.value)

        # Go back in the auto-coloring index, one step per deleted plot
        for _ in names:
            self._auto_color_idx = \
                (self._auto_color_idx - 1) % len(self.auto_colors)

        # Cull the candidate list of sources to remove by checking the other plots
        sources_in_use = set()
        for p in itertools.chain(*self.plots.values()):
            sources_in_use.add(p.index)
            sources_in_use.add(p.value)

        unused_sources = deleted_sources - sources_in_use - set([None])

        # Remove the unused sources from all ranges
        for source in unused_sources:
            if source.index_dimension == "scalar":
                # Try both index and range, it doesn't hurt
                self.index_range.remove(source)
                self.value_range.remove(source)
            elif source.index_dimension == "image":
                self.range2d.remove(source)
            else:
                warnings.warn("Couldn't remove datasource from datarange.")

        # Forget the unused sources.  ``datasources`` is keyed by *data*
        # name, not plot name, so match entries by the source object.
        for ds_name, source in list(self.datasources.items()):
            if source in unused_sources:
                del self.datasources[ds_name]
        return

    #------------------------------------------------------------------------
    # Private methods
    #------------------------------------------------------------------------

    def _make_new_plot_name(self):
        """ Returns a string that is not already used as a plot title. """
        n = len(self.plots)
        while True:
            name = "plot%d" % n
            if name not in self.plots:
                return name
            n += 1

    def _create_1d_mapper(self, scale, data_range, stretch_data,
                          screen_bounds=None):
        """ Build a LinearMapper if *scale* is "linear", else a LogMapper.

        *screen_bounds* is only forwarded when given, so freshly created
        mappers keep the mapper's own default bounds.
        """
        kw = dict(range=data_range, stretch_data=stretch_data)
        if screen_bounds is not None:
            kw['screen_bounds'] = screen_bounds
        mapper_cls = LinearMapper if scale == "linear" else LogMapper
        return mapper_cls(**kw)

    def _get_or_create_datasource(self, name, sort_order='none'):
        """ Returns the data source associated with the given name, or creates
        it if it doesn't exist.

        Raises ValueError for arrays that are not 1D and for data of any
        other unsupported type.
        """
        if name not in self.datasources:
            data = self.data.get_data(name)

            if isinstance(data, (list, tuple)):
                data = array(data)

            if isinstance(data, ndarray):
                if len(data.shape) == 1:
                    ds = ArrayDataSource(data, sort_order=sort_order)
                else:
                    raise ValueError("Unhandled array shape in creating new plot: "
                                     + str(data.shape))
            elif isinstance(data, AbstractDataSource):
                ds = data
            else:
                raise ValueError("Couldn't create datasource for data of type "
                                 + str(type(data)))

            self.datasources[name] = ds

        return self.datasources[name]

    @staticmethod
    def _release_half_fixed_range(data_range):
        """ If exactly one bound of *data_range* is fixed, set it to 'auto'
        as well, so the range follows the changed data. """
        low_auto = data_range.low_setting == 'auto'
        high_auto = data_range.high_setting == 'auto'
        if low_auto and not high_auto:
            data_range.set_high('auto')
        elif high_auto and not low_auto:
            data_range.set_low('auto')

    #------------------------------------------------------------------------
    # Event handlers
    #------------------------------------------------------------------------

    def _data_changed(self, old, new):
        if old:
            old.on_trait_change(self._data_update_handler, "data_changed",
                                remove=True, dispatch='ui')
        if new:
            new.on_trait_change(self._data_update_handler, "data_changed",
                                dispatch='ui')

    def _data_update_handler(self, name, event):
        # event should be a dict with keys "added", "removed", and "changed",
        # per the comments in AbstractPlotData.  Only "changed" needs work:
        # push the new arrays into the cached datasources.
        for data_name in event.get("changed", ()):
            source = self.datasources.get(data_name)
            if source is None:
                continue
            if source in self.index_range.sources:
                self._release_half_fixed_range(self.index_range)
            if source in self.value_range.sources:
                self._release_half_fixed_range(self.value_range)
            source.set_data(self.data.get_data(data_name))

    def _plots_items_changed(self, event):
        if self.legend:
            self.legend.plots = self.plots

    def _index_scale_changed(self, old, new):
        """ Swap the container's and every renderer's index mapper between
        linear and log.  Raises ValueError for non-XY renderers. """
        if old is None or new == old or not self.range2d:
            return
        self.index_mapper = self._create_1d_mapper(
            self.index_scale, self.index_range,
            self.index_mapper.stretch_data,
            screen_bounds=self.index_mapper.screen_bounds)
        stretch = self.index_mapper.stretch_data
        for plot in itertools.chain(*self.plots.values()):
            if not isinstance(plot, BaseXYPlot):
                raise ValueError("log scale only supported on XY plots")
            plot.index_mapper = self._create_1d_mapper(
                self.index_scale, plot.index_range, stretch,
                screen_bounds=plot.index_mapper.screen_bounds)

    def _value_scale_changed(self, old, new):
        """ Swap the container's and every renderer's value mapper between
        linear and log.  Raises ValueError for non-XY renderers. """
        if old is None or new == old or not self.range2d:
            return
        self.value_mapper = self._create_1d_mapper(
            self.value_scale, self.value_range,
            self.value_mapper.stretch_data,
            screen_bounds=self.value_mapper.screen_bounds)
        stretch = self.value_mapper.stretch_data
        for plot in itertools.chain(*self.plots.values()):
            if not isinstance(plot, BaseXYPlot):
                raise ValueError("log scale only supported on XY plots")
            plot.value_mapper = self._create_1d_mapper(
                self.value_scale, plot.value_range, stretch,
                screen_bounds=plot.value_mapper.screen_bounds)

    def _legend_changed(self, old, new):
        self._overlay_change_helper(old, new)
        if new:
            new.plots = self.plots

    def _handle_range_changed(self, name, old, new):
        """ Overrides the DataView default behavior.

        Primarily changes how the list of renderers is looked up.
        """
        mapper = getattr(self, name + "_mapper")
        if mapper.range == old:
            mapper.range = new
        if old is not None:
            for datasource in old.sources[:]:
                old.remove(datasource)
                if new is not None:
                    new.add(datasource)
        range_name = name + "_range"
        for renderer in itertools.chain(*self.plots.values()):
            if hasattr(renderer, range_name):
                setattr(renderer, range_name, new)

    #------------------------------------------------------------------------
    # Property getters and setters
    #------------------------------------------------------------------------

    def _set_legend_alignment(self, align):
        if self.legend:
            self.legend.align = align

    def _get_legend_alignment(self):
        if self.legend:
            return self.legend.align
        return None
class CellCropController(BaseImageController):
    """Controller for cropping cells from images with a square template.

    The user picks a template region in one image.  locate_peaks() finds
    the template's cross-correlation peaks in every image, and crop_cells()
    stores a template-sized crop at every peak whose value lies inside the
    selected threshold.
    """
    zero = Int(0)
    template_plot = Instance(BasePlotContainer)
    template_data = Instance(ArrayPlotData)
    template_size = Range(low=2, high=512, value=64, cols=4)
    template_top = Range(low='zero', high='max_pos_y', value=20, cols=4)
    template_left = Range(low='zero', high='max_pos_x', value=20, cols=4)
    peaks = Dict({})
    ShowCC = Bool(False)
    max_pos_x = Int(256)
    max_pos_y = Int(256)
    is_square = Bool(True)
    peak_width = Range(low=2, high=200, value=10)
    numpeaks_total = Int(0, cols=5)
    numpeaks_img = Int(0, cols=5)
    _session_id = String('')

    def __init__(self, parent, treasure_chest=None, data_path='/rawdata',
                 *args, **kw):
        super(CellCropController, self).__init__(parent, treasure_chest,
                                                 data_path, *args, **kw)
        if self.chest is not None:
            self.numfiles = len(self.nodes)
            if self.numfiles > 0:
                self.init_plot()

    def data_updated(self):
        # reinitialize data
        self.__init__(parent=self.parent, treasure_chest=self.chest,
                      data_path=self.data_path)

    def _template_slice(self, image):
        """Return the square patch of *image* selected by the template
        traits: rows from template_top and columns from template_left, each
        template_size long (clipped by the image bounds)."""
        top = self.template_top
        left = self.template_left
        size = self.template_size
        return image[top:top + size, left:left + size]

    def init_plot(self):
        """Build the main image plot and the template plot for the active
        image, using the default template position and size."""
        active_image = self.get_active_image()
        self.plotdata.set_data('imagedata', active_image)
        self.plot = self.get_scatter_overlay_plot(
            array_plot_data=self.plotdata, title=self.get_active_name(),
            tools=['csr', 'colorbar', 'zoom', 'pan'])
        # pick an initial template with default parameters
        self.template_data = ArrayPlotData()
        self.template_plot = Plot(self.template_data,
                                  default_origin="top left")
        self.template_data.set_data('imagedata',
                                    self._template_slice(active_image))
        self.template_plot.img_plot('imagedata', title="Template")
        self.template_plot.aspect_ratio = 1  # square templates
        self.template_filename = self.get_active_name()
        self._get_max_positions()

    @on_trait_change("selected_index, ShowCC")
    def update_image(self):
        """Show either the active image or, when ShowCC is set, its cross
        correlation with the template."""
        if self.ShowCC:
            image = cv_funcs.xcorr(self.template_data.get_data('imagedata'),
                                   self.get_active_image())
            title = "Cross correlation of " + self.get_active_name()
        else:
            image = self.get_active_image()
            title = self.get_active_name()
        self.plotdata.set_data("imagedata", image)
        self.plot = self.get_scatter_overlay_plot(
            array_plot_data=self.plotdata,
            title=self.get_active_name(),
            tools=['csr', 'zoom', 'pan', 'colorbar'],
        )
        self.plot.aspect_ratio = float(image.shape[1]) / image.shape[0]
        self.set_plot_title(title)
        # keep the grid extent in sync with the displayed array's shape
        grid_data_source = self._base_plot.range2d.sources[0]
        grid_data_source.set_data(np.arange(image.shape[1]),
                                  np.arange(image.shape[0]))

    def update_CC(self):
        """Refresh the displayed cross correlation, if it is shown."""
        if self.ShowCC:
            CC = cv_funcs.xcorr(self.template_data.get_data('imagedata'),
                                self.get_active_image())
            self.plotdata.set_data("imagedata", CC)

    @on_trait_change('template_left, template_top, template_size')
    def update_template_data(self):
        """Re-cut the template from the active image; previously found
        peaks are discarded because they depend on the template."""
        self.template_data.set_data(
            'imagedata', self._template_slice(self.get_active_image()))
        self.template_filename = self.get_active_name()
        if self.numpeaks_total > 0:
            print("clearing peaks")
            self.peaks = {}
        # when template data changes, we should check whether to update the
        # cross correlation plot, which depends on the template
        self.update_CC()

    @on_trait_change('selected_index, template_size')
    def _get_max_positions(self):
        """Update the largest allowed template_left/template_top so the
        template stays inside the active image (only when positive)."""
        shape = self.get_active_image().shape
        max_pos_x = shape[-1] - self.template_size - 1
        if max_pos_x > 0:
            self.max_pos_x = int(max_pos_x)
        max_pos_y = shape[-2] - self.template_size - 1
        if max_pos_y > 0:
            self.max_pos_y = int(max_pos_y)

    @on_trait_change('template_left, template_top')
    def update_csr_position(self):
        """Move the cursor to the template's top-left corner."""
        if self.template_left > 0:
            self._csr.current_position = self.template_left, self.template_top

    @on_trait_change('_csr:current_position')
    def update_top_left(self):
        """Make the template follow the cursor.

        Ignored while the cursor is at or before the origin in both
        directions.  Past max_pos_x only the top follows the cursor; if the
        cursor is also past max_pos_y it is snapped back to
        (max_pos_x, max_pos_y).  Past max_pos_y alone, top is clamped to
        max_pos_y.
        """
        x, y = self._csr.current_position
        if x <= 0 and y <= 0:
            return
        if x > self.max_pos_x:
            if y < self.max_pos_y:
                self.template_top = y
            else:
                self._csr.current_position = self.max_pos_x, self.max_pos_y
        elif y > self.max_pos_y:
            self.template_left, self.template_top = x, self.max_pos_y
        else:
            self.template_left, self.template_top = x, y

    @on_trait_change('_colorbar_selection:selection')
    def update_thresh(self):
        """Copy the colorbar selection into the threshold traits and the
        scatter renderer's selection metadata (best effort)."""
        try:
            thresh = self._colorbar_selection.selection
            self.thresh = thresh
            scatter_renderer = self._scatter_plot.plots['scatter_plot'][0]
            scatter_renderer.color_data.metadata['selections'] = thresh
            self.thresh_lower = thresh[0]
            self.thresh_upper = thresh[1]
            scatter_renderer.color_data.metadata_changed = {
                'selections': thresh
            }
            self.plot.request_redraw()
        except Exception:
            # the selection or scatter plot may not exist yet
            pass

    @on_trait_change('thresh_upper,thresh_lower')
    def manual_thresh_update(self):
        """Apply manually entered threshold bounds to the scatter plot."""
        self.thresh = [self.thresh_lower, self.thresh_upper]
        scatter_renderer = self._scatter_plot.plots['scatter_plot'][0]
        scatter_renderer.color_data.metadata['selections'] = self.thresh
        scatter_renderer.color_data.metadata_changed = {
            'selections': self.thresh
        }
        self.plot.request_redraw()

    @on_trait_change('peaks, _colorbar_selection:selection, selected_index')
    def calc_numpeaks(self):
        """Count the peaks whose column-2 value lies inside the threshold.

        numpeaks_total counts over all images and numpeaks_img over the
        active image only.  Without a usable selection, (-1, 1) is used.
        """
        try:
            thresh = self._colorbar_selection.selection
            self.thresh = thresh
        except Exception:
            thresh = []
        if thresh is None or len(thresh) == 0:
            thresh = (-1, 1)

        def count_inside(image_peaks):
            # masked_inside masks values inside [low, high], so the mask is
            # True exactly for the peaks to be counted
            return np.sum(np.ma.masked_inside(image_peaks[:, 2],
                                              thresh[0], thresh[1]).mask)

        self.numpeaks_total = int(
            np.sum([count_inside(p) for p in self.peaks.values()]))
        try:
            self.numpeaks_img = int(
                count_inside(self.peaks[self.get_active_name()]))
        except Exception:
            # typically: no peaks located for the active image yet
            self.numpeaks_img = 0

    @on_trait_change('peaks, selected_index')
    def update_scatter_plot(self):
        """Show the active image's peaks as a scatter overlay.

        Columns 1, 0 and 2 of the peak array are published as "index",
        "value" and "color".  Without peaks, those arrays are removed.
        """
        name = self.get_active_name()
        if name in self.peaks:
            image_peaks = self.peaks[name]
            self.plotdata.set_data("index", image_peaks[:, 1])
            self.plotdata.set_data("value", image_peaks[:, 0])
            self.plotdata.set_data("color", image_peaks[:, 2])
            self.plot = self.get_scatter_overlay_plot(
                array_plot_data=self.plotdata,
                tools=['zoom', 'pan', 'colorbar'])
            scatter_renderer = self._scatter_plot.plots['scatter_plot'][0]
            scatter_renderer.color_data.metadata['selections'] = self.thresh
            scatter_renderer.color_data.metadata_changed = {
                'selections': self.thresh
            }
        else:
            if 'index' in self.plotdata.arrays:
                self.plotdata.del_data('index')
                # "value" is always set together with "index", so it exists
                # whenever "index" does.
                self.plotdata.del_data('value')
            if 'color' in self.plotdata.arrays:
                self.plotdata.del_data('color')
            self.plot = self.get_scatter_overlay_plot(
                array_plot_data=self.plotdata)

    def locate_peaks(self):
        """Find cross-correlation peaks of the template in every image.

        The correlation map is shifted to be non-negative and scaled by 255
        before peak finding, and column 2 of the peak list is mapped back to
        correlation units.  The result is stored in self.peaks, keyed by
        image name.  The active index is switched to each file in turn.
        """
        peaks = {}
        for idx in xrange(self.numfiles):
            self.set_active_index(idx)
            CC = cv_funcs.xcorr(self.template_data.get_data("imagedata"),
                                self.get_active_image())
            cc_min = CC.min()
            pks = pc.two_dim_findpeaks((CC - cc_min) * 255, xc_filter=False)
            # undo the shift/scale applied before peak finding
            pks[:, 2] = pks[:, 2] / 255 + cc_min
            peaks[self.get_active_name()] = pks
        self.peaks = peaks

    def mask_peaks(self, image_id):
        """Return the peaks of *image_id* as a masked array whose column-2
        entries outside [thresh[0], thresh[1]] are masked."""
        mpeaks = np.ma.asarray(self.peaks[image_id])
        mpeaks[:, 2] = np.ma.masked_outside(mpeaks[:, 2], self.thresh[0],
                                            self.thresh[1])
        return mpeaks

    def crop_cells(self):
        """Crop a template-sized cell at every thresholded peak and store it.

        Previous cell tables and cropped cells are removed first.  The
        template and one 3D array of crops per file are stored, followed by
        an "average" cell averaged over the crops of *all* files.  Finally
        the action is logged and the crop session is closed.
        """
        rows = self.chest.root.cell_description.nrows
        if rows > 0:
            # remove the table
            self.chest.remove_node('/cell_description')
            try:
                # remove the table of peak characteristics - they are not valid.
                self.chest.remove_node('/cell_peaks')
            except Exception:
                # best effort: the table may not have been created yet
                pass
            # recreate it
            self.chest.create_table('/', 'cell_description', CellsTable)
            # remove all existing entries in the data group
            for node in self.chest.list_nodes('/cells'):
                self.chest.remove_node('/cells/' + node.name)
        # store the template
        template_data = self.template_data.get_data('imagedata')
        self.parent.add_cell_data(template_data, name="template")
        # TODO: set attribute that tells where the template came from
        row = self.chest.root.cell_description.row
        files = []
        tmp_sz = self.template_size
        # every non-empty per-file stack of crops, for the overall average
        cell_stacks = []
        for idx in xrange(self.numfiles):
            self.set_active_index(idx)
            active_image = self.get_active_image()
            name = self.get_active_name()
            # filter the peaks that are outside the selected threshold
            peaks = np.ma.compress_rows(self.mask_peaks(name))
            files.append(name)
            data = np.zeros((peaks.shape[0], tmp_sz, tmp_sz),
                            dtype=active_image.dtype)
            if data.shape[0] == 0:
                continue
            for i in xrange(peaks.shape[0]):
                # store the peak in the table
                row['file_idx'] = i
                row['input_data'] = self.data_path
                row['filename'] = name
                row['x_coordinate'] = peaks[i, 0]
                row['y_coordinate'] = peaks[i, 1]
                row.append()
                # crop the cells from the given locations; peak coordinates
                # may be floats, which numpy rejects as slice indices
                top = int(peaks[i, 0])
                left = int(peaks[i, 1])
                data[i, :, :] = active_image[top:top + tmp_sz,
                                             left:left + tmp_sz]
            self.chest.root.cell_description.flush()
            # insert the data (one 3d array per file)
            self.parent.add_cell_data(data, name=name)
            cell_stacks.append(data)
        self.chest.set_node_attr('/cell_description', 'threshold',
                                 (self.thresh_lower, self.thresh_upper))
        self.chest.set_node_attr('/cell_description', 'template_position',
                                 (self.template_left, self.template_top))
        self.chest.set_node_attr('/cell_description', 'template_filename',
                                 self.template_filename)
        self.chest.set_node_attr('/cell_description', 'template_size',
                                 self.template_size)
        self.chest.root.cell_description.flush()
        self.chest.flush()
        # average over the crops of every file, not only the last one
        if cell_stacks:
            all_cells = np.concatenate(cell_stacks, axis=0)
        else:
            # nothing passed the threshold: np.average of an empty stack
            # yields NaNs (with a RuntimeWarning)
            all_cells = np.zeros((0, tmp_sz, tmp_sz))
        average_data = np.average(all_cells, axis=0).squeeze()
        self.parent.add_cell_data(average_data, name="average")
        row = self.chest.root.cell_description.row
        row['file_idx'] = 0
        row['input_data'] = self.data_path
        row['filename'] = "average"
        row['x_coordinate'] = 0
        row['y_coordinate'] = 0
        row.append()
        self.chest.root.cell_description.flush()
        self.parent.update_cell_data()
        self.parent.add_image_data(average_data, "average")
        self.log_action(action="crop cells",
                        files=files,
                        thresh=self.thresh,
                        template_position=(self.template_left,
                                           self.template_top),
                        template_size=self.template_size,
                        template_filename=self.template_filename)
        Application.instance().end_session(self._session_id)
class Plotter2D(HasPreferenceTraits):
    """ 2D image plotter with a colorbar, zoom/pan/range bars and a
    user-selectable colormap.

    The image is the "c" entry of ``data``.  Its extent comes from
    ``x_min``/``x_max`` and ``y_min``/``y_max`` and is applied by
    request_update_plots_index().  Label, axis-format and index updates
    are dispatched on the UI thread so they may be triggered from other
    threads.
    """

    plot = Instance(Plot2D)
    colorbar = Instance(ColorBar)
    container = Instance(HPlotContainer)

    zoom_bar_plot = Instance(ZoomBar)
    zoom_bar_colorbar = Instance(ZoomBar)
    pan_bar = Instance(PanBar)
    range_bar = Instance(RangeBar)

    data = Instance(ArrayPlotData, ())
    x_min = Float(0.0)
    x_max = Float(1.0)
    y_min = Float(0.0)
    y_max = Float(1.0)

    add_contour = Bool(False)
    x_axis_label = Str
    y_axis_label = Str
    c_axis_label = Str

    x_axis_formatter = Instance(AxisFormatter)
    y_axis_formatter = Instance(AxisFormatter)
    c_axis_formatter = Instance(AxisFormatter)

    colormap = Enum(color_map_name_dict.keys(), preference='async')
    _cmap = Trait(Greys, Callable)

    # fired (possibly from another thread) to rebuild the plots' grid index
    update_index = Event

    traits_view = View(
        Group(
            Group(
                UItem('container', editor=ComponentEditor()),
                VGroup(
                    UItem('zoom_bar_colorbar', style='custom'),
                ),
                orientation='horizontal',
            ),
            Group(
                Group(
                    UItem('zoom_bar_plot', style='custom'),
                    UItem('pan_bar', style='custom'),
                    UItem('range_bar', style='custom'),
                    Group(
                        UItem('colormap'),
                        label='Color map',
                    ),
                    orientation='horizontal',
                ),
                orientation='vertical',
            ),
            orientation='vertical',
        ),
        resizable=True
    )

    preference_view = View(
        HGroup(
            VGroup(
                Item('x_axis_formatter', style='custom',
                     editor=InstanceEditor(view='preference_view'),
                     label='X axis',
                     ),
                Item('y_axis_formatter', style='custom',
                     editor=InstanceEditor(view='preference_view'),
                     label='Y axis',
                     ),
            ),
            Item('c_axis_formatter', style='custom',
                 editor=InstanceEditor(view='preference_view'),
                 label='C axis',
                 ),
            show_border=True,
            label='Axis format',
        ),
    )

    def __init__(self, **kwargs):
        super(Plotter2D, self).__init__(**kwargs)

        self.x_axis_formatter = AxisFormatter(pref_name='X axis format',
                                              pref_parent=self)
        self.y_axis_formatter = AxisFormatter(pref_name='Y axis format',
                                              pref_parent=self)
        self.c_axis_formatter = AxisFormatter(pref_name='C axis format',
                                              pref_parent=self)

        self.data = ArrayPlotData()
        self.plot = Plot2D(self.data)
        self.plot.padding = (80, 50, 10, 40)
        self.plot.x_axis.tick_label_formatter = \
            self.x_axis_formatter.float_format
        self.plot.y_axis.tick_label_formatter = \
            self.y_axis_formatter.float_format
        self.pan_bar = PanBar(self.plot)
        self.zoom_bar_plot = zoom_bar(self.plot, x=True, y=True, reset=True)

        # Dummy plot so that the color bar can be correctly initialized
        xs = linspace(-2, 2, 600)
        ys = linspace(-1.2, 1.2, 300)
        self.x_min = xs[0]
        self.x_max = xs[-1]
        self.y_min = ys[0]
        self.y_max = ys[-1]
        x, y = meshgrid(xs, ys)
        self.data.set_data('c', x * y)
        self.plot.img_plot('c',
                           name='c',
                           colormap=self._cmap,
                           xbounds=(self.x_min, self.x_max),
                           ybounds=(self.y_min, self.y_max),
                           )

        # Create the colorbar, the appropriate range and colormap are handled
        # at the plot creation
        self.colorbar = ColorBar(
            index_mapper=LinearMapper(range=self.plot.color_mapper.range),
            color_mapper=self.plot.color_mapper,
            plot=self.plot,
            orientation='v',
            resizable='v',
            width=20,
            padding=10)
        self.colorbar.padding_top = self.plot.padding_top
        self.colorbar.padding_bottom = self.plot.padding_bottom
        self.colorbar._axis.tick_label_formatter = \
            self.c_axis_formatter.float_format

        self.container = HPlotContainer(self.plot,
                                        self.colorbar,
                                        use_backbuffer=True,
                                        bgcolor="lightgray")

        # Add pan and zoom tools to the colorbar
        self.colorbar.tools.append(PanTool(self.colorbar,
                                           constrain_direction="y",
                                           constrain=True))
        self.zoom_bar_colorbar = zoom_bar(self.colorbar,
                                          box=False,
                                          reset=True,
                                          orientation='vertical')

        # Add the range bar now that we are sure that we have a color_mapper
        self.range_bar = RangeBar(self.plot)
        self.x_axis_label = 'X'
        self.y_axis_label = 'Y'
        self.c_axis_label = 'C'
        self.sync_trait('x_axis_label', self.range_bar, alias='x_name')
        self.sync_trait('y_axis_label', self.range_bar, alias='y_name')
        self.sync_trait('c_axis_label', self.range_bar, alias='c_name')

        # Dynamically bind the update methods for traits likely to be updated
        # from another thread
        self.on_trait_change(self.new_x_label, 'x_axis_label',
                             dispatch='ui')
        self.on_trait_change(self.new_y_label, 'y_axis_label',
                             dispatch='ui')
        self.on_trait_change(self.new_c_label, 'c_axis_label',
                             dispatch='ui')
        self.on_trait_change(self.new_x_axis_format, 'x_axis_formatter.+',
                             dispatch='ui')
        self.on_trait_change(self.new_y_axis_format, 'y_axis_formatter.+',
                             dispatch='ui')
        self.on_trait_change(self.new_c_axis_format, 'c_axis_formatter.+',
                             dispatch='ui')
        self.on_trait_change(self._update_plots_index, 'update_index',
                             dispatch='ui')

        # set the default colormap in the editor
        self.colormap = 'Blues'

        self.preference_init()

    # bound to 'x_axis_label' in __init__ (dispatch='ui')
    def new_x_label(self, new):
        self.plot.x_axis.title = new

    # bound to 'y_axis_label' in __init__ (dispatch='ui')
    def new_y_label(self, new):
        self.plot.y_axis.title = new

    # bound to 'c_axis_label' in __init__ (dispatch='ui')
    def new_c_label(self, new):
        self.colorbar._axis.title = new

    @on_trait_change('colormap')
    def new_colormap(self, new):
        """ Apply the colormap named *new* to every colormapped renderer,
        to the plot and to the colorbar. """
        self._cmap = color_map_name_dict[new]
        # The colormap may be set (e.g. via constructor kwargs) before the
        # plot exists; _cmap is then picked up by the first img_plot.
        if self.plot is None:
            return
        colormapped = (ImagePlot, CMapImagePlot, ContourPolyPlot)
        last_mapper = None
        for plots in self.plot.plots.itervalues():
            for plot in plots:
                if isinstance(plot, colormapped):
                    plot.color_mapper = self._cmap(plot.color_mapper.range)
                    last_mapper = plot.color_mapper
        if last_mapper is None:
            return
        # Share the renderer's mapper with the plot and the colorbar so the
        # colorbar shows the colormap actually used by the image.
        self.plot.color_mapper = last_mapper
        if self.colorbar is not None:
            self.colorbar.color_mapper = last_mapper
        if self.container is not None:
            self.container.request_redraw()

    # bound to 'x_axis_formatter.+' in __init__ (dispatch='ui')
    def new_x_axis_format(self):
        self.plot.x_axis._invalidate()
        self.plot.invalidate_and_redraw()

    # bound to 'y_axis_formatter.+' in __init__ (dispatch='ui')
    def new_y_axis_format(self):
        self.plot.y_axis._invalidate()
        self.plot.invalidate_and_redraw()

    # bound to 'c_axis_formatter.+' in __init__ (dispatch='ui')
    def new_c_axis_format(self):
        self.colorbar._axis._invalidate()
        self.plot.invalidate_and_redraw()

    def request_update_plots_index(self):
        """ Ask (from any thread) for the plots' grid index to be rebuilt
        on the UI thread. """
        self.update_index = True

    # bound to 'update_index' in __init__ (dispatch='ui')
    def _update_plots_index(self):
        """ Rebuild the grid index of the plot and of every renderer from
        the x/y bounds, with one more grid edge than the "c" image has
        pixels along each axis.  Does nothing if "c" is missing. """
        if 'c' not in self.data.list_data():
            return
        c_data = self.data.get_data('c')
        xs = linspace(self.x_min, self.x_max, c_data.shape[1] + 1)
        ys = linspace(self.y_min, self.y_max, c_data.shape[0] + 1)
        order = ('ascending', 'ascending')
        self.plot.range2d.remove(self.plot.index)
        self.plot.index = GridDataSource(xs, ys, sort_order=order)
        self.plot.range2d.add(self.plot.index)
        for plots in self.plot.plots.itervalues():
            for plot in plots:
                # each renderer gets its own grid source
                plot.index = GridDataSource(xs, ys, sort_order=order)
class Demo(HasTraits):
    """ Two side-by-side image plots, each overlaid with a filled and a line
    contour plot, with File menu actions to save/load the image.
    """

    # Image data shown in every sub-plot (key 'imagedata').
    pd = Instance(ArrayPlotData, ())

    # Horizontal container holding the sub-plots.
    plot = Instance(HPlotContainer)

    # Image file read by _load(); defaults to the bundled demo image.
    _load_file = File(
        find_resource('imageAlignment', '../images/GIRLS-IN-SPACE.jpg',
                      '../images/GIRLS-IN-SPACE.jpg', return_path=True))

    # Destination file written by _save().
    _save_file = File

    load_file_view = View(
        Item('_load_file'),
        buttons=OKCancelButtons,
        kind='livemodal',
        width=400,
        resizable=True,
    )

    save_file_view = View(
        Item('_save_file'),
        buttons=OKCancelButtons,
        kind='livemodal',
        width=400,
        resizable=True,
    )

    def __init__(self, *args, **kwargs):
        """ Load the demo image and build one Plot per title in the container.
        """
        super(Demo, self).__init__(*args, **kwargs)
        from imread import imread

        imarray = imread(find_resource('imageAlignment',
                                       '../images/GIRLS-IN-SPACE.jpg',
                                       '../images/GIRLS-IN-SPACE.jpg',
                                       return_path=True))
        self.pd = ArrayPlotData(imagedata=imarray)
        self.plot = HPlotContainer()
        titles = ["I KEEP DANCE", "ING ON MY OWN"]
        # Replace 'imagedata' with the contents of _load_file (best effort).
        self._load()

        # The contour data and the image are identical for every sub-plot,
        # so compute/fetch them once outside the loop.
        xs = linspace(0, 334*pi, 333)
        ys = linspace(0, 334*pi, 333)
        x, y = meshgrid(xs, ys)
        z = x*y
        image = self.pd.get_data('imagedata')

        for plot_title in titles:
            _pd = ArrayPlotData()
            _pd.set_data("drawdata", z)
            _pd.set_data("imagedata", image)

            plc = Plot(_pd, title=plot_title, padding=50,
                       border_visible=True, overlay_border=True)
            self.plot.add(plc)

            plc.img_plot("imagedata", alpha=0.95)
            # Filled contour polygons of the data...
            plc.contour_plot("drawdata", type="poly", poly_cmap=jet,
                             xbounds=(0, 499), ybounds=(0, 582), alpha=0.35)
            # ...and contour lines on top of them.
            plc.contour_plot("drawdata", type="line",
                             xbounds=(0, 499), ybounds=(0, 582), alpha=0.35)

            plc.legend.visible = True

            plc.tools.append(PanTool(plc))
            zoom = ZoomTool(component=plc, tool_mode="box", always_on=False)
            plc.overlays.append(zoom)

            plc.bg_color = None
            plc.fill_padding = True

    def default_traits_view(self):
        """ Return the main window View with a File menu (save/load/close).

        NOTE(review): ``size``, ``title`` and ``ImageFileController`` are not
        defined in this class; presumably module-level names -- confirm.
        """
        traits_view = View(
            Group(
                Item('plot', editor=ComponentEditor(size=size),
                     show_label=False),
                orientation="vertical"),
            menubar=MenuBar(
                Menu(Action(name="Save Plot", action="save"),
                     Action(name="Load Plot", action="load"),
                     Separator(), CloseAction,
                     name="File")),
            resizable=True,
            title=title,
            handler=ImageFileController)
        return traits_view

    def _save(self):
        """ Render the whole plot container and save it to ``_save_file``. """
        win_size = self.plot.outer_bounds
        plot_gc = PlotGraphicsContext(win_size)
        plot_gc.render_component(self.plot)
        plot_gc.save(self._save_file)

    def _load(self):
        """ Load ``_load_file`` into the 'imagedata' array of ``pd``.

        Best effort: any failure is printed and otherwise ignored, leaving
        the previously stored image in place.
        """
        try:
            image = ImageData.fromfile(self._load_file)
            self.pd.set_data('imagedata', image._data)
            self.plot.title = "YO DOGG: %s" % os.path.basename(self._load_file)
            self.plot.request_redraw()
        except Exception as exc:
            print("YO DOGG: %s" % exc)
class ElectrodeHysteresis(HasTraits):
    """ Hysteresis of two magnetic electrodes and the resulting response of a
    double quantum dot (DQD) coupled to a cavity.

    Electrode orientations are obtained by minimising ``stoner_free_energy``
    (defined elsewhere) along a field sweep; the DQD response is computed
    from the relative electrode angle with ``energy_dqd`` and ``kappa``.
    """

    # Electrode 1: coercive field and easy-axis angle. The angle is typed as
    # an expression in easy_axis_angle_input1 and evaluated into
    # easy_axis_angle1.
    coercitive_field1 = Float(0.01)
    easy_axis_angle1 = Float(0)
    easy_axis_angle_input1 = Str('0')

    # Electrode 2: same as above.
    coercitive_field2 = Float(0.02)
    easy_axis_angle2 = Float(0)
    easy_axis_angle_input2 = Str('Pi/6')

    # Field sweep range and number of samples.
    field_min = Float(-0.015)
    field_max = Float(0.015)
    points = Int(1000)

    # Cavity / DQD parameters (units as labelled in the view).
    cavity_freq = Float(6.7)
    cavity_quality = Float(3500)
    # True once _compute_fired has stored the 'aux1'/'aux2' angle arrays.
    dqd_ready = Bool(False)
    dqd_t = Float(6.2)
    dqd_coupling = Float(10)
    dqd_delta = Float(3)
    dqd_gamma12 = Float(0.1)
    dqd_gamma13 = Float(0.1)
    lande = Float(28)

    data1 = Instance(ArrayPlotData)
    plot1 = Instance(Plot)
    data2 = Instance(ArrayPlotData)
    plot2 = Instance(Plot)
    data_diff = Instance(ArrayPlotData)
    plot_diff = Instance(Plot)
    data_g = Instance(ArrayPlotData)
    plot_g = Instance(Plot)

    compute = Button('Compute')

    traits_view = View(
        VGroup(
            HGroup(
                VGroup(
                    Item('coercitive_field1'),
                    Item('easy_axis_angle_input1',
                         editor=TextEditor(auto_set=False, enter_set=True),
                         ),
                    Item('easy_axis_angle1', style='readonly'),
                    show_border=True,
                    label='Electrode 1',
                ),
                VGroup(
                    Item('coercitive_field2'),
                    Item('easy_axis_angle_input2',
                         editor=TextEditor(auto_set=False, enter_set=True),
                         ),
                    Item('easy_axis_angle2', style='readonly'),
                    show_border=True,
                    label='Electrode 2',
                ),
            ),
            HGroup(
                Item('field_min'),
                Item('field_max'),
                Item('points'),
            ),
            HGroup(
                Spring(),
                Heading('Increasing = blue, decreasing = red'),
                Spring(),
                UItem('compute'),
            ),
            HGroup(
                UItem('plot1', editor=ComponentEditor()),
                UItem('plot2', editor=ComponentEditor()),
                UItem('plot_diff', editor=ComponentEditor()),
            ),
            HGroup(
                VGroup(
                    HGroup(Item('cavity_freq'), Label('GHz')),
                    Item('cavity_quality'),
                    HGroup(Item('dqd_t'), Label('GHz')),
                    HGroup(Item('dqd_coupling'), Label('MHz')),
                    HGroup(Item('dqd_delta'), Label('GHz')),
                    HGroup(Item('dqd_gamma12'), Label('GHz')),
                    HGroup(Item('dqd_gamma13'), Label('GHz')),
                    HGroup(Item('lande', style='readonly'), Label('GHz/T')),
                ),
                UItem('plot_g', editor=ComponentEditor()),
            ),
        ),
        resizable=True,
    )

    def __init__(self):
        super(ElectrodeHysteresis, self).__init__()
        # Evaluate the default angle expressions into the numeric traits.
        self.new_value1(self.easy_axis_angle_input1)
        self.new_value2(self.easy_axis_angle_input2)

        self.data1, self.plot1 = self._make_plot(
            'Electrode 1', 'Orientation (rad)')
        self.data2, self.plot2 = self._make_plot(
            'Electrode 2', 'Orientation (rad)')
        self.data_diff, self.plot_diff = self._make_plot(
            'Current', 'Current (a.u.)')
        self.data_g, self.plot_g = self._make_plot(
            'DQD response', 'Phase contrast (rad)')

    @staticmethod
    def _make_plot(title, value_title):
        """ Build an (ArrayPlotData, Plot) pair with flat placeholder lines.

        The plot shows 'y1' in blue and 'y2' in red against 'x'.
        """
        data = ArrayPlotData()
        plot = Plot(data, title=title)
        plot.padding = (80, 10, 20, 40)
        plot.index_axis.title = 'Field (mT)'
        plot.value_axis.title = value_title

        dummy_x = linspace(0, 10, 1000)
        data.set_data('x', dummy_x)
        data.set_data('y1', 0*dummy_x)
        data.set_data('y2', 0*dummy_x)
        plot.plot(('x', 'y1'), color='blue')
        plot.plot(('x', 'y2'), color='red')
        return data, plot

    @staticmethod
    def _relax(phi, easy_axis_angle, field, inv_field_coercion):
        """ Minimise ``stoner_free_energy`` at ``field`` starting from ``phi``.

        Returns the 1-element array produced by ``optimize.fmin``.
        """
        def energy(angle):
            return stoner_free_energy(angle, easy_axis_angle, field,
                                      inv_field_coercion)
        return optimize.fmin(energy, [phi], disp=False)

    def _hysteresis_branches(self, easy_axis_angle, inv_field_coercion):
        """ Sweep the field for one electrode; return (increasing, decreasing).

        The orientation is first relaxed on a ramp from 0 to field_max
        (points // 2 steps), then swept from field_max down to field_min and
        back up (points steps each). Angles are taken modulo 2*Pi. Both
        returned lists are ordered by increasing field.
        """
        phi = 0
        # Integer division: linspace needs an integer sample count.
        for field in linspace(0, self.field_max, self.points // 2):
            phi = self._relax(phi, easy_axis_angle, field, inv_field_coercion)

        decreasing = []
        for field in linspace(self.field_max, self.field_min, self.points):
            phi = self._relax(phi, easy_axis_angle, field, inv_field_coercion)
            decreasing.append(float(phi[0]) % (2*Pi))

        increasing = []
        for field in linspace(self.field_min, self.field_max, self.points):
            phi = self._relax(phi, easy_axis_angle, field, inv_field_coercion)
            increasing.append(float(phi[0]) % (2*Pi))

        # The down-sweep was recorded from field_max to field_min; flip it so
        # both branches line up with the ascending 'x' grid.
        decreasing.reverse()
        return increasing, decreasing

    @staticmethod
    def _show_branches(data, plot, increasing, decreasing):
        """ Store both branches as 'y1'/'y2' and re-enable value autoscaling.
        """
        data.set_data('y1', increasing)
        data.set_data('y2', decreasing)
        plot.value_range.low_setting = 'auto'
        plot.value_range.high_setting = 'auto'

    def _compute_fired(self):
        """ Recompute both electrode hysteresis loops and the DQD response. """
        field = linspace(self.field_min, self.field_max, self.points)
        for data in (self.data1, self.data2, self.data_diff, self.data_g):
            data.set_data('x', field)

        inv_field_coercion1 = 1/self.coercitive_field1
        inv_field_coercion2 = 1/self.coercitive_field2

        phi1_incr, phi1_decr = self._hysteresis_branches(
            self.easy_axis_angle1, inv_field_coercion1)
        phi2_incr, phi2_decr = self._hysteresis_branches(
            self.easy_axis_angle2, inv_field_coercion2)

        phi_diff_incr = array(phi1_incr) - array(phi2_incr)
        phi_diff_decr = array(phi1_decr) - array(phi2_decr)

        self._show_branches(self.data1, self.plot1, phi1_incr, phi1_decr)
        self._show_branches(self.data2, self.plot2, phi2_incr, phi2_decr)
        self._show_branches(self.data_diff, self.plot_diff,
                            cos(phi_diff_incr), cos(phi_diff_decr))

        self.data_g.set_data('aux1', mod(phi_diff_incr, Pi))
        self.data_g.set_data('aux2', mod(phi_diff_decr, Pi))
        self.dqd_ready = True
        # Explicit call: dqd_ready only notifies when it actually changes.
        self.compute_dqd_answer()

    def _dqd_response(self, delta, theta):
        """ DQD phase response for one hysteresis branch.

        delta: detuning array (dqd_delta + lande * field).
        theta: relative electrode angle (modulo Pi) for that branch.
        """
        t = self.dqd_t
        ener1 = -energy_dqd(t, delta, theta, +1)
        ener2 = -energy_dqd(t, delta, theta, -1)
        ener3 = energy_dqd(t, delta, theta, -1)

        aux = delta*sin(theta/2)
        kmm = kappa(t, delta, theta, -1, -1)
        kpm = kappa(t, delta, theta, +1, -1)
        kmp = kappa(t, delta, theta, -1, +1)

        g12 = (aux*kmm + aux*kpm)/sqrt((aux**2 + kmm**2)*(aux**2 + kpm**2))
        g13 = (aux*kpm - aux*kmp)/sqrt((aux**2 + kmp**2)*(aux**2 + kpm**2))

        detun12 = ener2 - ener1 - self.cavity_freq
        detun13 = ener3 - ener1 - self.cavity_freq

        response = (detun12*g12**2/(detun12**2 + self.dqd_gamma12**2)
                    + detun13*g13**2/(detun13**2 + self.dqd_gamma13**2))
        response *= (2*self.cavity_quality/self.cavity_freq
                     * self.dqd_coupling**2/1000000)
        return response

    # cavity_freq and cavity_quality also enter the response, so they
    # trigger a recompute too.
    @on_trait_change('dqd_t, dqd_delta, dqd_gamma12, dqd_gamma13, '
                     'dqd_coupling, dqd_ready, cavity_freq, cavity_quality')
    def compute_dqd_answer(self):
        """ Fill 'y1'/'y2' of data_g with the DQD response of both branches.

        Does nothing until _compute_fired has produced the angle arrays.
        """
        if not self.dqd_ready:
            return
        mag_field = self.lande*self.data_g.get_data('x')
        delta = self.dqd_delta + mag_field
        dqd_response_1 = self._dqd_response(delta,
                                            self.data_g.get_data('aux1'))
        dqd_response_2 = self._dqd_response(delta,
                                            self.data_g.get_data('aux2'))
        self.data_g.set_data('y1', dqd_response_1)
        self.data_g.set_data('y2', dqd_response_2)

    # NOTE(review): both handlers below ``eval`` the text typed in the GUI
    # (e.g. 'Pi/6') with module globals -- arbitrary code execution; fine for
    # a local tool only.
    @on_trait_change('easy_axis_angle_input1')
    def new_value1(self, new):
        self.easy_axis_angle1 = eval(new)

    @on_trait_change('easy_axis_angle_input2')
    def new_value2(self, new):
        self.easy_axis_angle2 = eval(new)
class TwoDimensionalPlot(ChacoPlot):
    """ A 2D plot showing a single (x, y) line renderer. """

    # Palette cycled by auto_color(). The index is stored on the class
    # (updated through ``cls``), so successive plots get successive colors.
    auto_color_idx = 0
    auto_color_list = ['green', 'brown', 'blue', 'red', 'black']

    @classmethod
    def auto_color(cls):
        """ Choose the next color, wrapping around at the end of the list. """
        color = cls.auto_color_list[cls.auto_color_idx]
        cls.auto_color_idx = (cls.auto_color_idx + 1) % len(cls.auto_color_list)
        return color

    def __init__(self, parent, color=None, *args, **kwargs):
        """ Create the plot with a placeholder (0, 0) line.

        parent: passed to ``Window`` when ``control`` is accessed.
        color: line color; when None, the next color from auto_color().
        Remaining arguments are forwarded to ``ChacoPlot.__init__``.
        """
        self.parent = parent

        if color is None:
            color = self.auto_color()

        self.data = ArrayPlotData()
        self.data.set_data('x', [0])
        self.data.set_data('y', [0])

        ChacoPlot.__init__(self, self.data, *args, **kwargs)

        self.plot(('x', 'y'), color=color)

        self.configure()

    @property
    def control(self):
        """ A drawable control. Note: a new ``Window`` is built on every access.
        """
        return Window(self.parent, component=self).control

    def get_data(self, axis):
        """ Values for an axis. """
        return self.data.get_data(axis)

    def set_data(self, values, axis):
        """ Replace the values of an axis. """
        self.data.set_data(axis, values)

    # ``property`` calls fget(obj) / fset(obj, value); the partials fix the
    # axis keyword so obj lands in ``self`` and value in ``values``.
    x_data = property(partial(get_data, axis='x'), partial(set_data, axis='x'))
    y_data = property(partial(get_data, axis='y'), partial(set_data, axis='y'))

    def _first_renderer(self):
        """ Return the first renderer of the first entry in ``self.plots``.

        Uses ``next(iter(...))`` instead of ``values()[0]`` so it also works
        where ``dict.values()`` is a non-indexable view (Python 3).
        """
        return next(iter(self.plots.values()))[0]

    def x_autoscale(self):
        """ Enable autoscaling for the x axis. """
        x_range = self._first_renderer().index_mapper.range
        x_range.low = x_range.high = 'auto'

    def y_autoscale(self):
        """ Enable autoscaling for the y axis. """
        y_range = self._first_renderer().value_mapper.range
        y_range.low = y_range.high = 'auto'
class StdXYPlotFactory(BasePlotFactory):
    """ Factory to create a 2D plot with one or more renderers of the same kind
    """
    #: Generated chaco plot containing all requested renderers
    plot = Instance(MultiMapperPlot)

    #: List of plot_data keys to plot in pairs, one pair per renderer
    renderer_desc = List(Dict)

    #: Renderer list, mapped to their name
    renderers = Dict

    #: Optional legend object to be added to the future plot
    legend = Instance(Legend)

    def __init__(self, x_arr=None, y_arr=None, z_arr=None, hover_data=None,
                 **traits):
        """ Build the factory, creating plot_data from the arrays if needed.

        Parameters
        ----------
        x_arr, y_arr, z_arr : ndarray, dict, pandas.Series or None
            Data to plot. Series are converted to their underlying arrays.
            Dicts map a hue value to the array of the matching renderer.
        hover_data : dict or None
            Additional named arrays stored in the plot data alongside x/y.
        **traits
            Forwarded to BasePlotFactory.
        """
        super(StdXYPlotFactory, self).__init__(**traits)

        # Unwrap each pandas Series independently:
        if isinstance(x_arr, pd.Series):
            x_arr = x_arr.values
        if isinstance(y_arr, pd.Series):
            y_arr = y_arr.values
        if isinstance(z_arr, pd.Series):
            z_arr = z_arr.values

        if hover_data is None:
            hover_data = {}

        # Only build plot data when plot_data is still unset (presumably it
        # can be supplied through **traits -- confirm in BasePlotFactory):
        if self.plot_data is None:
            self.initialize_plot_data(x_arr=x_arr, y_arr=y_arr, z_arr=z_arr,
                                      **hover_data)

        self.adjust_plot_style()

    def adjust_plot_style(self):
        """ Translate general plotting style info into xy plot parameters.
        """
        pass

    def initialize_plot_data(self, x_arr=None, y_arr=None, z_arr=None,
                             **adtl_arrays):
        """ Set the plot_data and the list of renderer descriptions.

        If the data arrays are dictionaries rather than straight arrays, they
        describe multiple renderers.

        Returns
        -------
        dict
            Mapping of ArrayPlotData keys to arrays used to build plot_data.

        Raises
        ------
        ValueError
            If x_arr or y_arr is missing, or if x_arr is neither a numpy
            array nor a dict.
        """
        if x_arr is None or y_arr is None:
            msg = "2D plots require a valid plot_data or an array for both x" \
                  " and y."
            # Not inside an except block: logger.exception would append a
            # meaningless "NoneType: None" traceback.
            logger.error(msg)
            raise ValueError(msg)

        if isinstance(x_arr, np.ndarray):
            data_map = self._plot_data_single_renderer(x_arr, y_arr, z_arr,
                                                       **adtl_arrays)
        elif isinstance(x_arr, dict):
            assert set(x_arr.keys()) == set(y_arr.keys()), \
                "x_arr and y_arr must have the same hue keys"
            data_map = self._plot_data_multi_renderer(x_arr, y_arr, z_arr,
                                                      **adtl_arrays)
        else:
            msg = "x_arr/y_arr should be either an array or a dictionary " \
                  "mapping the z/hue value to the corresponding x array, but" \
                  " {} ({}) was passed."
            msg = msg.format(x_arr, type(x_arr))
            logger.error(msg)
            raise ValueError(msg)

        self.plot_data = ArrayPlotData(**data_map)
        return data_map

    def _plot_data_single_renderer(self, x_arr=None, y_arr=None, z_arr=None,
                                   **adtl_arrays):
        """ Build the data_map to build the plot data.
        """
        data_map = {self.x_col_name: x_arr, self.y_col_name: y_arr}
        data_map.update(adtl_arrays)
        renderer_data = {
            "x": self.x_col_name,
            "y": self.y_col_name,
            "name": DEFAULT_RENDERER_NAME
        }
        self.renderer_desc = [renderer_data]
        return data_map

    def _plot_data_multi_renderer(self, x_arr=None, y_arr=None, z_arr=None,
                                  **adtl_arrays):
        """ Build the data_map to build the plot data for multiple renderers.

        One renderer description is appended per hue value, in sorted hue
        order.
        """
        data_map = {}
        for hue_val_idx, hue_val in enumerate(sorted(x_arr.keys())):
            hue_name, x_name, y_name = self._add_arrays_for_hue(
                data_map, x_arr, y_arr, hue_val, hue_val_idx, adtl_arrays)
            renderer_data = {"x": x_name, "y": y_name, "name": hue_name}
            self._hue_values.append(hue_name)
            self.renderer_desc.append(renderer_data)
        return data_map

    def _add_arrays_for_hue(self, data_map, x_arr, y_arr, hue_val,
                            hue_val_idx, adtl_arrays):
        """ Build and collect all arrays to add to ArrayPlotData for hue value.
        """
        hue_name = str(hue_val)
        x_name = self._plotdata_array_key(self.x_col_name, hue_name)
        y_name = self._plotdata_array_key(self.y_col_name, hue_name)
        data_map[x_name] = x_arr[hue_val]
        data_map[y_name] = y_arr[hue_val]
        # Collect any additional dataset that needs to be stored (for
        # e.g. to feed plot tools)
        for adtl_col, col_data in adtl_arrays.items():
            key = self._plotdata_array_key(adtl_col, hue_name)
            data_map[key] = col_data[hue_val]
        return hue_name, x_name, y_name

    def _plotdata_array_key(self, col_name, hue_name=""):
        """ Name of the ArrayPlotData containing the array from specified col.

        Parameters
        ----------
        col_name : str
            Name of the column being displayed.

        hue_name : str
            Name of the renderer color the array will be used in. Typically
            the coloring column value, converted to string.
        """
        return col_name + hue_name

    def generate_plot(self):
        """ Generate and return a dict containing a plot and its properties.
        """
        plot = self.plot = MultiMapperPlot(
            **self.plot_style.container_style.to_traits()
        )
        # Emulate chaco.Plot interface:
        plot.data = self.plot_data

        self.add_renderers(plot)
        self.set_axis_labels(plot)

        if len(self.renderer_desc) > 1:
            self.set_legend(plot)

        self.add_tools(plot)

        # Build a description of the plot to build a PlotDescriptor
        desc = dict(plot_type=self.plot_type, plot=plot, visible=True,
                    plot_title=self.plot_title, x_col_name=self.x_col_name,
                    y_col_name=self.y_col_name, x_axis_title=self.x_axis_title,
                    y_axis_title=self.y_axis_title, z_col_name=self.z_col_name,
                    z_axis_title=self.z_axis_title, plot_factory=self)

        if self.plot_style.container_style.include_colorbar:
            self.generate_colorbar(desc)
            self.add_colorbar(desc)

        return desc

    def add_tools(self, plot):
        """ Add all tools specified in plot_tools list to provided plot.

        Pan/zoom tools are attached per renderer through a shared
        BroadcasterTool; legend and context-menu tools are added once.
        """
        renderers = plot.components
        if renderers:
            broadcaster = BroadcasterTool()
            # IMPORTANT: add the broadcast tool to one of the renderers, NOT
            # the container. Otherwise, the box zoom will crop wrong:
            renderers[0].tools.append(broadcaster)

            for renderer in renderers:
                if "pan" in self.plot_tools:
                    pan = PanTool(renderer)
                    broadcaster.tools.append(pan)

                if "zoom" in self.plot_tools:
                    # FIXME: the zoom tool is added to the broadcaster's tools
                    # attribute because it doesn't have an overlay list. That
                    # means the box plot mode won't display the blue box!
                    zoom = ZoomTool(component=renderer, zoom_factor=1.15)
                    broadcaster.tools.append(zoom)

        if "legend" in self.plot_tools and self.legend:
            legend = self.legend
            legend.tools.append(
                LegendTool(component=self.legend, drag_button="right"))
            legend.tools.append(LegendHighlighter(component=legend))

        if "context_menu" in self.plot_tools:
            self.context_menu_manager.target = self.plot
            menu = self.context_menu_manager.build_menu()
            context_menu = ContextMenuTool(component=self.plot,
                                           menu_manager=menu)
            self.plot.tools.append(context_menu)

    def add_renderers(self, plot):
        """ Add all renderers to provided plot container.

        Raises
        ------
        ValueError
            If the number of renderer styles doesn't match the number of
            renderer descriptions.
        """
        styles = self.plot_style.renderer_styles
        if len(styles) != len(self.renderer_desc):
            msg = "Something went wrong: received {} styles and {} renderer " \
                  "descriptions.".format(len(styles), len(self.renderer_desc))
            logger.error(msg)
            raise ValueError(msg)

        for i, (desc, style) in enumerate(zip(self.renderer_desc, styles)):
            self.add_renderer(plot, desc, style, first_renderer=(i == 0))

        self.align_all_renderers(plot)

    @staticmethod
    def _mapper_class_for_style(axis_style):
        """ LogMapper for a log-scaled axis style, LinearMapper otherwise. """
        if axis_style.scaling == LOG_AXIS_STYLE:
            return LogMapper
        return LinearMapper

    def add_renderer(self, plot, desc, style, first_renderer=False):
        """ Create and add to plot renderer described by desc and style.

        If the axis it is displayed along isn't already created, create it
        too, and add it to the plot's list of underlays.
        """
        # Modify the renderer's style's name so it is displayed in the style
        # view:
        style.renderer_name = desc["name"]

        renderer = self._build_renderer(desc, style)
        plot.add(renderer)
        self.renderers[desc["name"]] = renderer

        if first_renderer:
            left_axis, bottom_axis = add_default_axes(renderer)
            # Emulate chaco.Plot interface:
            plot.x_axis = bottom_axis
            plot.y_axis = left_axis
            renderer.underlays = []
            plot.underlays = [bottom_axis, left_axis]
        elif style.orientation == STYLE_R_ORIENT and \
                plot.second_y_axis is None:
            mapper_klass = self._mapper_class_for_style(
                self.plot_style.second_y_axis_style)
            # The range needs to be initialized so the axis can be aligned
            # with all secondary y axis renderers:
            mapper = mapper_klass(range=DataRange1D())
            second_y_axis = PlotAxis(component=renderer, orientation="right",
                                     mapper=mapper)
            plot.second_y_axis = second_y_axis
            plot.underlays.append(second_y_axis)

        return renderer

    def align_all_renderers(self, plot):
        """ Align all renderers in index and value dimensions to plot's axis.

        This method is used to keep renderers aligned with the displayed axes
        once their ranges have been set.
        """
        # Follow renderer_desc order so each renderer is paired with the
        # style at the same index (dict iteration order isn't relied on):
        all_renderers = [self.renderers[desc["name"]]
                         for desc in self.renderer_desc]
        if len(all_renderers) <= 1:
            return

        styles = self.plot_style.renderer_styles
        align_renderers(all_renderers, plot.x_axis, dim="index")
        if plot.second_y_axis is not None:
            l_renderers = [
                rend for rend, style in zip(all_renderers, styles)
                if style.orientation == STYLE_L_ORIENT
            ]
            r_renderers = [
                rend for rend, style in zip(all_renderers, styles)
                if style.orientation == STYLE_R_ORIENT
            ]
            align_renderers(l_renderers, plot.y_axis, dim="value")
            align_renderers(r_renderers, plot.second_y_axis, dim="value")
        else:
            align_renderers(all_renderers, plot.y_axis, dim="value")

    def _build_renderer(self, desc, style):
        """ Invoke appropriate renderer factory to build and return renderer.
        """
        renderer_maker = RENDERER_MAKER[style.renderer_type]
        x = self.plot_data.get_data(desc["x"])
        y = self.plot_data.get_data(desc["y"])

        x_mapper_class = self._mapper_class_for_style(
            self.plot_style.x_axis_style)

        if style.orientation == STYLE_L_ORIENT:
            y_style = self.plot_style.y_axis_style
        else:
            y_style = self.plot_style.second_y_axis_style
        y_mapper_class = self._mapper_class_for_style(y_style)

        return renderer_maker(data=(x, y), index_mapper_class=x_mapper_class,
                              value_mapper_class=y_mapper_class,
                              **style.to_plot_kwargs())

    def set_legend(self, plot, align="ur", padding=10):
        """ Add legend and make it relocatable & clickable if tools requested.

        FIXME: Add control over legend labels.
        """
        # FIXME: make sure plot list in legend doesn't include error bar
        # renderers.
        legend = Legend(component=plot, padding=padding, align=align,
                        title=self.z_axis_title)
        legend.plots = self.renderers
        legend.visible = True
        plot.overlays.append(legend)
        # Emulate chaco.Plot-like behavior:
        self.legend = legend

    # Post creation renderer management methods -------------------------------

    def update_renderers_from_data(self, removed=None):
        """ The plot_data was updated: update/remove existing renderers.

        Raises
        ------
        ValueError
            If only one of a renderer's x/y keys is listed in ``removed``.
        """
        if removed is None:
            removed = []

        rend_desc_map = {desc["name"]: desc for desc in self.renderer_desc}

        # Iterate over a copy of the names: remove_renderer mutates
        # self.renderers.
        for name in list(self.renderers.keys()):
            renderer = self.renderers[name]
            desc = rend_desc_map[name]
            x_removed = desc["x"] in removed
            y_removed = desc["y"] in removed
            if x_removed and y_removed:
                self.remove_renderer(desc)
            elif x_removed != y_removed:
                msg = "Unable to update the renderer {}: the data seems to be"\
                      " incomplete because x was set as removed and not y or" \
                      " vice versa. Removed keys: {}. Please report this " \
                      "issue.".format(desc["name"], removed)
                logger.error(msg)
                raise ValueError(msg)
            else:
                x = self.plot_data.get_data(desc["x"])
                y = self.plot_data.get_data(desc["y"])
                renderer.index.set_data(x)
                renderer.value.set_data(y)

    def remove_renderer(self, rend_desc):
        """ Remove renderer described by provided descriptor from current plot.
        """
        rend_name = rend_desc["name"]
        renderer = self.renderers.pop(rend_name)
        self.plot.remove(renderer)
        for rend_idx, desc in enumerate(self.renderer_desc):
            if desc["name"] == rend_name:
                # Safe despite mutating: we stop iterating right away.
                self.renderer_desc.pop(rend_idx)
                self.plot_style.renderer_styles.pop(rend_idx)
                break

        if self.legend:
            self.legend.plots.pop(rend_name)

    def append_new_renderers(self, desc_list, styles):
        """ Append new renderers to an existing factory plot.
        """
        num_existing_renderers = len(self.renderer_desc)
        for i, (rend_desc, rend_style) in enumerate(zip(desc_list, styles)):
            rend_idx = num_existing_renderers + i
            renderer = self.add_renderer(self.plot, rend_desc, rend_style,
                                         first_renderer=(rend_idx == 0))
            self.renderer_desc.append(rend_desc)
            self.plot_style.renderer_styles.append(rend_style)
            if self.legend:
                self.legend.plots[rend_desc["name"]] = renderer

    # Traits initialization methods -------------------------------------------

    def _plot_tools_default(self):
        return {"zoom", "pan", "legend", "context_menu"}