class HistogramPlotHandler(HasTraits):
    """
    Class for handling the histograms.

    ``self.data`` is read by draw_histogram() but is not declared here; the
    caller must assign it first.  It is indexed by column name
    (``data.columns``) when AS_PANDAS_DATAFRAME is set, otherwise as a 2-D
    array (``data[:, column]``).
    """

    # Index for the histogram plot
    index = Array

    # The selection handler object for the selected data
    selection_handler = Instance(SelectionHandler)

    # OverlayPlotContainer for the histogram plot
    container = Instance(OverlayPlotContainer)

    # Number of bins of the histogram
    nbins = Int(10)

    # Whether the data is a pandas dataframe or a numpy array
    AS_PANDAS_DATAFRAME = Bool

    def __init__(self, **traits):
        super(HistogramPlotHandler, self).__init__(**traits)
        # Array traits expect an ndarray; a bare ``range`` is not one on
        # Python 3.
        self.index = np.arange(self.nbins)
        self.selection_handler = SelectionHandler()
        self.container = OverlayPlotContainer()

    def _column_values(self, column):
        """Return the values of the data column at position *column*."""
        if self.AS_PANDAS_DATAFRAME:
            return self.data[self.data.columns[column]]
        return self.data[:, column]

    def draw_histogram(self):
        """
        Default function called when drawing the histogram.

        Clears the container and, when exactly one entry is selected, draws a
        bar histogram (``nbins`` bins) of the column that entry belongs to.
        """
        # Iterate over a copy: removing from the live list while iterating it
        # would skip every other component.
        for component in list(self.container.components):
            self.container.remove(component)

        self.selection_handler.create_selection()
        selected = self.selection_handler.selected_indices
        if len(selected) == 1:
            # Element [1] of a selected entry is used as the column position.
            y = self._column_values(selected[0][1])
            self.index = np.arange(self.nbins)
            hist = np.histogram(y, self.nbins)[0]
            plot = Plot(ArrayPlotData(x=self.index, y=hist))
            plot.plot(("x", "y"), type="bar", bar_width=0.5)
            self.container.add(plot)

        # Redraw even when nothing was plotted so stale bars disappear.
        self.container.request_redraw()
        self.selection_handler.flush()
class HistogramPlotHandler(HasTraits):
    """
    Class for handling the histograms.

    ``self.data`` is read by draw_histogram() but is not declared here; the
    caller must assign it first.  It is indexed by column name
    (``data.columns``) when AS_PANDAS_DATAFRAME is set, otherwise as a 2-D
    array (``data[:, column]``).
    """

    # Index for the histogram plot
    index = Array

    # The selection handler object for the selected data
    selection_handler = Instance(SelectionHandler)

    # OverlayPlotContainer for the histogram plot
    container = Instance(OverlayPlotContainer)

    # Number of bins of the histogram
    nbins = Int(10)

    # Whether the data is a pandas dataframe or a numpy array
    AS_PANDAS_DATAFRAME = Bool

    def __init__(self, **traits):
        super(HistogramPlotHandler, self).__init__(**traits)
        # Array traits expect an ndarray; a bare ``range`` is not one on
        # Python 3.
        self.index = np.arange(self.nbins)
        self.selection_handler = SelectionHandler()
        self.container = OverlayPlotContainer()

    def _column_values(self, column):
        """Return the values of the data column at position *column*."""
        if self.AS_PANDAS_DATAFRAME:
            return self.data[self.data.columns[column]]
        return self.data[:, column]

    def draw_histogram(self):
        """
        Default function called when drawing the histogram.

        Clears the container and, when exactly one entry is selected, draws a
        bar histogram (``nbins`` bins) of the column that entry belongs to.
        """
        # Iterate over a copy: removing from the live list while iterating it
        # would skip every other component.
        for component in list(self.container.components):
            self.container.remove(component)

        self.selection_handler.create_selection()
        selected = self.selection_handler.selected_indices
        if len(selected) == 1:
            # Element [1] of a selected entry is used as the column position.
            y = self._column_values(selected[0][1])
            self.index = np.arange(self.nbins)
            hist = np.histogram(y, self.nbins)[0]
            plot = Plot(ArrayPlotData(x=self.index, y=hist))
            plot.plot(("x", "y"), type="bar", bar_width=0.5)
            self.container.add(plot)

        # Redraw even when nothing was plotted so stale bars disappear.
        self.container.request_redraw()
        self.selection_handler.flush()
class CellCropper(StackViewer):
    """
    Template picker.

    The user places a square template on the current image.  The template is
    cross-correlated (cv_funcs.xcorr) with every image served by the
    controller, and correlation peaks are thresholded through a colour bar.
    crop_cells() then cuts the matching cells out and hands them to the
    controller.

    ``data``, ``img_idx``, ``numfiles``, ``img_plotdata`` and ``controller``
    are used but not declared here; presumably they come from StackViewer
    (TODO confirm).
    """

    # Square patch of the current image at (top, left) with side tmp_size
    template = Array
    # Cross-correlation of the template with the current image
    CC = Array
    # One array per image; columns are x, y and the correlation value
    peaks = List
    # Lower bound for the top/left ranges
    zero = Int(0)
    # Side length of the (square) template, in pixels
    tmp_size = Range(low=2, high=512, value=64, cols=4)
    # Largest allowed left/top coordinate for the current template size
    max_pos_x = Property(depends_on=['tmp_size'])
    max_pos_y = Property(depends_on=['tmp_size'])
    # Template position within the image
    top = Range(low='zero', high='max_pos_y', value=20, cols=4)
    left = Range(low='zero', high='max_pos_x', value=20, cols=4)
    is_square = Bool
    # Plot showing the template patch
    tmp_plot = Instance(Plot)
    # Button that runs locate_peaks()
    findpeaks = Button
    # Peak width passed to the peak finder
    peak_width = Range(low=2, high=200, value=10)
    # Show the cross-correlation image instead of the raw image
    ShowCC = Bool
    # Number of peaks inside the threshold: all images / current image
    numpeaks_total = Int(0, cols=5)
    numpeaks_img = Int(0, cols=5)
    # Dialog handler used by traits_view
    OK_custom = OK_custom_handler
    # Current (low, high) threshold on the correlation values, or None
    thresh = Trait(None, None, List, Tuple, Array)
    thresh_upper = Float(1.0)
    thresh_lower = Float(-1.0)
    # Image index at which the template was last taken
    tmp_img_idx = Int(0)
    # Cursor used to drag the template position on the image
    csr = Instance(BaseCursorTool)

    traits_view = View(
        Group(
            Group(
                Item("img_container", editor=ComponentEditor(), show_label=False),
                HGroup(
                    Item("ShowCC", editor=BooleanEditor(),
                         label="Show cross correlation image"),
                    Spring(),
                    Item("prev_img", editor=ButtonEditor(label="<"),
                         show_label=False, enabled_when='numfiles > 1'),
                    Item("next_img", editor=ButtonEditor(label=">"),
                         show_label=False, enabled_when='numfiles > 1'),
                ),
                label="Original image", show_border=True,
                trait_modified="tab_selected",
                orientation='vertical',
            ),
            VGroup(
                Group(
                    HGroup(
                        Item("left", label="Left coordinate", style="custom"),
                        Spring(),
                        Item("top", label="Top coordinate", style="custom"),
                    ),
                    Item("tmp_size", label="Template size", style="custom"),
                    Item("tmp_plot", editor=ComponentEditor(height=256, width=256),
                         show_label=False, resizable=True),
                    label="Template", show_border=True,
                ),
                Group(
                    HGroup(
                        Item("peak_width", label="Peak width", style="custom"),
                        Spring(),
                        Item("findpeaks", editor=ButtonEditor(label="Find Peaks"),
                             show_label=False),
                    ),
                    HGroup(
                        Item("thresh_lower", label="Threshold Lower Value",
                             editor=TextEditor(evaluate=float, format_str='%1.4f')),
                        Spring(),
                        Item("thresh_upper", label="Threshold Upper Value",
                             editor=TextEditor(evaluate=float, format_str='%1.4f')),
                    ),
                    HGroup(
                        Item("numpeaks_img",
                             label="Number of Cells selected (this image)",
                             style='readonly'),
                        Spring(),
                        Item("numpeaks_total", label="Total", style='readonly'),
                    ),
                    label="Peak parameters", show_border=True,
                ),
            ),
            orientation='horizontal',
        ),
        buttons=[Action(name='OK', enabled_when='numpeaks_total > 0'),
                 CancelButton],
        title="Template Picker",
        handler=OK_custom,
        kind='livemodal',
        key_bindings=key_bindings,
        width=940, height=530, resizable=True)

    def __init__(self, controller, *args, **kw):
        super(CellCropper, self).__init__(controller, *args, **kw)
        # Within this class the OpenCV import is only an availability check;
        # the correlation itself goes through cv_funcs.xcorr.
        try:
            import cv
        except ImportError:
            try:
                import cv2.cv as cv
            except ImportError:
                print("OpenCV unavailable. Can't do cross correlation without it. Aborting.")
                return None
        self.OK_custom = OK_custom_handler()
        self.template = self.data[self.top:self.top + self.tmp_size,
                                  self.left:self.left + self.tmp_size]
        tmp_plot_data = ArrayPlotData(imagedata=self.template)
        tmp_plot = Plot(tmp_plot_data, default_origin="top left")
        tmp_plot.img_plot("imagedata", colormap=jet)
        tmp_plot.aspect_ratio = 1.0
        self.tmp_plot = tmp_plot
        self.tmp_plotdata = tmp_plot_data
        self.crop_sig = None

    def render_image(self):
        """Render the current image in grey with a draggable template cursor."""
        plot = super(CellCropper, self).render_image()
        img = plot.img_plot("imagedata", colormap=gray)[0]
        csr = CursorTool(img, drag_button='left', color='white', line_width=2.0)
        self.csr = csr
        csr.current_position = self.left, self.top
        img.overlays.append(csr)
        self.img_plot = plot
        return plot

    # TODO: use base class render_scatter_overlay method
    def render_scatplot(self):
        """Build the scatter overlay of the current image's peaks, coloured by
        correlation value and sharing the image plot's range."""
        img_peaks = self.peaks[self.img_idx]
        peakdata = ArrayPlotData()
        peakdata.set_data("index", img_peaks[:, 0])
        peakdata.set_data("value", img_peaks[:, 1])
        peakdata.set_data("color", img_peaks[:, 2])
        scatplot = Plot(peakdata, aspect_ratio=self.img_plot.aspect_ratio,
                        default_origin="top left")
        scatplot.plot(("index", "value", "color"),
                      type="cmap_scatter",
                      name="my_plot",
                      color_mapper=jet(DataRange1D(low=0.0, high=1.0)),
                      marker="circle",
                      fill_alpha=0.5,
                      marker_size=6,
                      )
        scatplot.x_grid.visible = False
        scatplot.y_grid.visible = False
        scatplot.range2d = self.img_plot.range2d
        self.scatplot = scatplot
        self.peakdata = peakdata
        return scatplot

    def _image_plot_container(self):
        """Build the image container: image (+ peak overlay) and colour bar."""
        plot = self.render_image()
        # Create a container to position the plot and the colorbar side-by-side
        self.container = OverlayPlotContainer()
        self.container.add(plot)
        self.img_container = HPlotContainer(use_backbuffer=False)
        self.img_container.add(self.container)
        self.img_container.bgcolor = "white"
        if self.numpeaks_img > 0:
            scatplot = self.render_scatplot()
            self.container.add(scatplot)
            colorbar = self.draw_colorbar()
            self.img_container.add(colorbar)
        return self.img_container

    def draw_colorbar(self):
        """Create the colour bar (range -1..1) whose range selection drives
        the peak threshold; applies the current self.thresh if set."""
        scatplot = self.scatplot
        cmap_renderer = scatplot.plots["my_plot"][0]
        selection = ColormappedSelectionOverlay(cmap_renderer, fade_alpha=0.35,
                                                selection_type="range")
        cmap_renderer.overlays.append(selection)
        if self.thresh is not None:
            cmap_renderer.color_data.metadata['selections'] = self.thresh
            cmap_renderer.color_data.metadata_changed = {'selections': self.thresh}
        # Create the colorbar, handing in the appropriate range and colormap
        colorbar = ColorBar(index_mapper=LinearMapper(range=DataRange1D(low=-1.0, high=1.0)),
                            orientation='v',
                            resizable='v',
                            width=30,
                            padding=20)
        colorbar_selection = RangeSelection(component=colorbar)
        colorbar.tools.append(colorbar_selection)
        colorbar.overlays.append(RangeSelectionOverlay(component=colorbar,
                                                       border_color="white",
                                                       alpha=0.8,
                                                       fill_color="lightgray",
                                                       metadata_name='selections'))
        self.cbar_selection = colorbar_selection
        self.cmap_renderer = cmap_renderer
        colorbar.plot = cmap_renderer
        colorbar.padding_top = scatplot.padding_top
        colorbar.padding_bottom = scatplot.padding_bottom
        self.colorbar = colorbar
        return colorbar

    @on_trait_change('ShowCC')
    def toggle_cc_view(self):
        """Switch the displayed image between the raw data and the CC map."""
        if self.ShowCC:
            self.update_CC()
            grid_data_source = self.img_plot.range2d.sources[0]
            grid_data_source.set_data(np.arange(self.CC.shape[1]),
                                      np.arange(self.CC.shape[0]))
        else:
            self.img_plotdata.set_data("imagedata", self.data)
        self.redraw_plots()

    @on_trait_change("img_idx")
    def update_img_depth(self):
        """
        TODO: We look up the index in the model - first get a list of files,
        then get the name of the file at the given index.
        """
        super(CellCropper, self).update_img_depth()
        if self.ShowCC:
            self.update_CC()
        self.redraw_plots()

    def _get_max_pos_x(self):
        # None when the template does not fit horizontally.
        max_pos_x = self.data.shape[-1] - self.tmp_size - 1
        if max_pos_x > 0:
            return max_pos_x
        return None

    def _get_max_pos_y(self):
        # None when the template does not fit vertically.
        max_pos_y = self.data.shape[-2] - self.tmp_size - 1
        if max_pos_y > 0:
            return max_pos_y
        return None

    @on_trait_change('left, top')
    def update_csr_position(self):
        """Move the cursor to follow the left/top coordinate editors."""
        if self.left > 0:
            self.csr.current_position = self.left, self.top

    @on_trait_change('csr:current_position')
    def update_top_left(self):
        """Copy the cursor position into left/top, clamped to max_pos_x/y."""
        x, y = self.csr.current_position[0], self.csr.current_position[1]
        if x > 0 or y > 0:
            if x > self.max_pos_x:
                if y < self.max_pos_y:
                    self.top = y
                else:
                    self.csr.current_position = self.max_pos_x, self.max_pos_y
            elif y > self.max_pos_y:
                self.left, self.top = x, self.max_pos_y
            else:
                self.left, self.top = x, y

    @on_trait_change('left, top, tmp_size')
    def update_tmp_plot(self):
        """Re-cut the template; previously found peaks are invalidated."""
        self.template = self.data[self.top:self.top + self.tmp_size,
                                  self.left:self.left + self.tmp_size]
        self.tmp_plotdata.set_data("imagedata", self.template)
        grid_data_source = self.tmp_plot.range2d.sources[0]
        grid_data_source.set_data(np.arange(self.tmp_size), np.arange(self.tmp_size))
        self.tmp_img_idx = self.img_idx
        if self.numpeaks_total > 0:
            print("clearing peaks")
            self.peaks = [np.array([[0, 0, -1]])]
        self.update_CC()

    def update_CC(self):
        """Recompute and display the CC map, only while ShowCC is on."""
        if self.ShowCC:
            self.CC = cv_funcs.xcorr(self.template, self.data)
            self.img_plotdata.set_data("imagedata", self.CC)

    @on_trait_change('cbar_selection:selection')
    def update_thresh(self):
        """Mirror the colour-bar range selection into thresh and the
        lower/upper threshold fields."""
        try:
            thresh = self.cbar_selection.selection
            self.thresh = thresh
            self.cmap_renderer.color_data.metadata['selections'] = thresh
            self.thresh_lower = thresh[0]
            self.thresh_upper = thresh[1]
            self.cmap_renderer.color_data.metadata_changed = {'selections': thresh}
            self.container.request_redraw()
            self.img_container.request_redraw()
        except Exception:
            # Best effort: the selection may presumably be cleared (None) or
            # the colour bar not yet wired up.
            pass

    @on_trait_change('thresh_upper,thresh_lower')
    def manual_thresh_update(self):
        """Apply thresholds typed into the text fields to the colour bar."""
        self.thresh = [self.thresh_lower, self.thresh_upper]
        cmap_renderer = getattr(self, "cmap_renderer", None)
        if cmap_renderer is None:
            # No peaks drawn yet; draw_colorbar() applies self.thresh later.
            return
        cmap_renderer.color_data.metadata['selections'] = self.thresh
        cmap_renderer.color_data.metadata_changed = {'selections': self.thresh}
        self.container.request_redraw()
        self.img_container.request_redraw()

    @staticmethod
    def _thresh_or_default(thresh):
        """Return *thresh*, or the full (-1, 1) range when it is None/empty."""
        if thresh is None or len(thresh) == 0:
            return (-1, 1)
        return thresh

    @on_trait_change('peaks,cbar_selection:selection,img_idx')
    def calc_numpeaks(self):
        """Count peaks whose correlation value lies inside the threshold."""
        try:
            thresh = self.cbar_selection.selection
            self.thresh = thresh
        except Exception:
            # No colour bar yet (or an unusable selection): use the default.
            thresh = None
        thresh = self._thresh_or_default(thresh)
        # masked_inside masks values inside [low, high], so the mask sum is
        # the number of accepted peaks.
        self.numpeaks_total = int(np.sum([
            np.sum(np.ma.masked_inside(img_peaks[:, 2], thresh[0], thresh[1]).mask)
            for img_peaks in self.peaks]))
        try:
            self.numpeaks_img = int(np.sum(np.ma.masked_inside(
                self.peaks[self.img_idx][:, 2], thresh[0], thresh[1]).mask))
        except (IndexError, TypeError):
            self.numpeaks_img = 0

    @on_trait_change('findpeaks')
    def locate_peaks(self):
        """Cross-correlate the template with every image and find peaks."""
        peaks = []
        progress = ProgressDialog(title="Peak finder progress",
                                  message="Finding peaks on %s images" % self.numfiles,
                                  max=self.numfiles,
                                  show_time=True,
                                  can_cancel=False)
        progress.open()
        for idx in range(self.numfiles):
            self.controller.set_active_index(idx)
            self.data = self.controller.get_active_image()[:]
            self.CC = cv_funcs.xcorr(self.template, self.data)
            # peak finder needs peaks greater than 1.  Multiply by 255 to scale them.
            pks = pc.two_dim_findpeaks(self.CC * 255, peak_width=self.peak_width,
                                       medfilt_radius=None)
            pks[:, 2] = pks[:, 2] / 255.
            peaks.append(pks)
            progress.update(idx + 1)
        self.peaks = peaks
        # NOTE(review): assigning peaks already fires redraw_plots; the extra
        # call presumably ensures numpeaks_img is up to date — confirm.
        self.redraw_plots()

    def mask_peaks(self, idx):
        """
        Return the peaks of image *idx* as a masked array whose correlation
        entry is masked when it lies outside the colour-bar selection, so that
        np.ma.compress_rows drops those rows.
        """
        try:
            thresh = self.cbar_selection.selection
        except AttributeError:
            # Colour bar never drawn: keep every peak.
            thresh = None
        thresh = self._thresh_or_default(thresh)
        mpeaks = np.ma.asarray(self.peaks[idx])
        mpeaks[:, 2] = np.ma.masked_outside(mpeaks[:, 2], thresh[0], thresh[1])
        return mpeaks

    @on_trait_change("peaks")
    def redraw_plots(self):
        """Rebuild the image plot and, when the current image has accepted
        peaks, the peak overlay and colour bar."""
        oldplot = self.img_plot
        self.container.remove(oldplot)
        newplot = self.render_image()
        self.container.add(newplot)
        self.img_plot = newplot
        # The scatter plot and colour bar only exist after peaks were drawn,
        # and may already be detached; remove each independently.
        oldscat = getattr(self, "scatplot", None)
        if oldscat is not None and oldscat in self.container.components:
            self.container.remove(oldscat)
        oldcolorbar = getattr(self, "colorbar", None)
        if oldcolorbar is not None and oldcolorbar in self.img_container.components:
            self.img_container.remove(oldcolorbar)
        if self.numpeaks_img > 0:
            newscat = self.render_scatplot()
            self.container.add(newscat)
            self.scatplot = newscat
            colorbar = self.draw_colorbar()
            self.img_container.add(colorbar)
            self.colorbar = colorbar
        self.container.request_redraw()
        self.img_container.request_redraw()

    def crop_cells(self):
        """Crop a tmp_size square at every accepted peak of every image and
        hand the stack to controller.add_cells()."""
        print("cropping cells...")
        tmp_sz = self.tmp_size
        for idx in range(self.numfiles):
            # filter the peaks that are outside the selected threshold
            self.controller.set_active_index(idx)
            self.data = self.controller.get_active_image()
            self.name = self.controller.get_active_name()
            peaks = np.ma.compress_rows(self.mask_peaks(idx))
            data = np.zeros((peaks.shape[0], tmp_sz, tmp_sz))
            if data.shape[0] > 0:
                for i in range(peaks.shape[0]):
                    # Peak coordinates may be floats; slice bounds must be ints.
                    row = int(peaks[i, 1])
                    col = int(peaks[i, 0])
                    # crop the cells from the given locations
                    data[i, :, :] = self.data[row:row + tmp_sz, col:col + tmp_sz]
                # send the data to the controller for storage in the chest
                self.controller.add_cells(name=self.name, data=data, locations=peaks)
class KMeansPlotHandler(HasTraits):
    """
    Class for plotting the k-means clusters.
    """

    # the data to cluster (rows are samples, columns are features)
    data = Array

    # the dataset created after preprocessing
    dataset = Array

    # the sklearn.cluster.KMeans object
    kmeans = Instance(KMeans)

    # Number of clusters
    n_clusters = Int

    # Maximum iterations for the clustering algorithm (applied when > 0)
    max_iter = Int

    # Container for the cluster plots
    container = Instance(OverlayPlotContainer)

    # the columns from the dataset to omit when performing clustering, given
    # as column positions (strings or ints); other entries are ignored
    to_omit = List

    def __init__(self, **traits):
        super(KMeansPlotHandler, self).__init__(**traits)
        self.kmeans = KMeans()
        self.container = OverlayPlotContainer()

    def create_dataset(self):
        """
        Creates a numpy array from the current selection to pass to the
        sklearn.cluster.kmeans object.

        Every column of ``data`` whose position is not listed in ``to_omit``
        is kept, in order; with nothing to omit the dataset is all of
        ``data``.
        """
        if self.data.ndim < 2:
            # Nothing column-shaped to select from; use the data as-is.
            self.dataset = self.data
            return
        omitted = set(int(elem) for elem in self.to_omit if str(elem).isdigit())
        keep = [col for col in range(self.data.shape[1]) if col not in omitted]
        # One fancy-indexing copy instead of growing the array with hstack.
        self.dataset = self.data[:, keep]

    def plot_clusters(self):
        """
        Plots the clusters after calling the .fit method of the sklearn kmeans
        estimator.

        The dataset is reduced to two dimensions with a whitened PCA for
        display.  The cluster centres are projected with the same fitted PCA,
        so that they land in the same coordinates as the points.
        """
        self.kmeans.n_clusters = self.n_clusters
        if self.max_iter > 0:
            self.kmeans.max_iter = self.max_iter
        self.kmeans.fit(self.dataset)

        pca = PCA(n_components=2, whiten=True)
        dataset_red = pca.fit_transform(self.dataset)
        cluster_centers = pca.transform(self.kmeans.cluster_centers_)

        # Copy first: removing while iterating the live list skips items.
        for component in list(self.container.components):
            self.container.remove(component)

        labels = self.kmeans.labels_
        for i in range(self.n_clusters):
            current_data = dataset_red[labels == i]
            plotdata = ArrayPlotData(x=current_data[:, 0], y=current_data[:, 1])
            plot = Plot(plotdata)
            # Cycle the palette so more clusters than colours still plot.
            color = tuple(COLOR_PALETTE[i % len(COLOR_PALETTE)])
            plot.plot(("x", "y"), type="scatter", color=color)
            self.container.add(plot)

        plotdata_cent = ArrayPlotData(x=cluster_centers[:, 0], y=cluster_centers[:, 1])
        plot_cent = Plot(plotdata_cent)
        plot_cent.plot(("x", "y"), type="scatter", marker="cross", marker_size=8)
        self.container.add(plot_cent)
        self.container.request_redraw()
class XYPlotHandler(HasTraits):
    """
    Class for handling XY plots
    """

    # Whether the data is a pandas dataframe
    AS_PANDAS_DATAFRAME = Bool

    # The container for all current plots. Gets updated everytime a plot is
    # added.
    container = OverlayPlotContainer

    # This can be removed.
    plotdata = ArrayPlotData

    # The current Plot object.
    plot = Plot

    # ColorTrait, mainly required for the TraitsUIItem view.
    color = ColorTrait("blue")

    # Marker trait for the view
    marker = marker_trait

    # Marker size trait
    marker_size = Int(4)

    # An instance of SelectionHandler for adding plots from the current
    # selection.
    selection_handler = Instance(SelectionHandler)

    # Bool traits for checking the type of the plot (discrete / continuous)
    plot_type_disc = Bool
    plot_type_cont = Bool

    # The data from which to draw the plots, same as the table attribute of
    # CsvModel
    table = Array

    # The pandas data frame if AS_PANDAS_DATAFRAME
    data_frame = Instance(DataFrame)

    # Contains the grid underlays of all the current plots
    grid_underlays = List

    # Used for viewing the list of the plots and the legend
    plot_list_view = Dict

    # TraitsUI view for plot properties, yet to find an enaml implementation
    view = View(Item("color"), Item("marker"), Item("marker_size"))

    # Trait that defines whether tools are present.
    add_pan_tool = Bool
    add_zoom_tool = Bool
    add_dragzoom = Bool

    # Whether grids and axes are visible
    show_grid = Bool

    def __init__(self):
        super(XYPlotHandler, self).__init__()
        self.selection_handler = SelectionHandler()
        self.container = OverlayPlotContainer()
        self.underlays = []
        self.add_pan_tool = False
        self.add_zoom_tool = False
        self.add_dragzoom = False
        self.show_grid = False

    def _unique_plot_name(self):
        """Return an auto-generated ``plotN`` name not already in use."""
        n = len(self.plot_list_view)
        name = "plot" + str(n)
        # After removals len() can point at a name that is still taken.
        while name in self.plot_list_view:
            n += 1
            name = "plot" + str(n)
        return name

    def add_xyplot_selection(self, plot_name):
        """
        Called when the 'add plot from selection' button is clicked.

        Plots the second selected column against the first.  The plot is
        registered under *plot_name*, or under a fresh ``plotN`` name when
        *plot_name* is empty.
        """
        self.selection_handler.create_selection()
        if self.selection_handler.xyplot_check():
            x_index = self.selection_handler.selected_indices[0][1]
            y_index = self.selection_handler.selected_indices[1][1]
            if self.AS_PANDAS_DATAFRAME:
                x = np.array(self.data_frame[self.data_frame.columns[x_index]])
                y = np.array(self.data_frame[self.data_frame.columns[y_index]])
                self.plotdata = ArrayPlotData(x=x, y=y)
            else:
                self.plotdata = ArrayPlotData(x=self.table[:, x_index],
                                              y=self.table[:, y_index])
            plot = Plot(self.plotdata)
            plot_type = "scatter" if self.plot_type_disc else "line"
            plot.plot(("x", "y"), type=plot_type, color=self.color,
                      marker=self.marker, marker_size=self.marker_size)
            self.plot = plot

            # Remember the new plot's grids for later toggling, then detach
            # every known grid from it.
            for underlay in self.plot.underlays:
                if isinstance(underlay, PlotGrid) and underlay not in self.grid_underlays:
                    self.grid_underlays.append(underlay)
            for underlay in self.grid_underlays:
                if underlay in self.plot.underlays:
                    self.plot.underlays.remove(underlay)

            if plot_name == "":
                plot_name = self._unique_plot_name()
            self.plot_list_view[plot_name] = self.plot
            self.container.add(self.plot)
            self.container.request_redraw()
        self.selection_handler.flush()

    def _set_grid_visible(self, visible):
        """Attach (visible) or detach every known grid underlay on all plots."""
        for plot in self.container.components:
            for underlay in self.grid_underlays:
                if visible and underlay not in plot.underlays:
                    plot.underlays.append(underlay)
                elif not visible and underlay in plot.underlays:
                    plot.underlays.remove(underlay)
        self.container.request_redraw()

    def grid_toggle(self, checked):
        """
        Called when the 'Show Grid' checkbox is toggled
        """
        self._set_grid_visible(checked)

    def remove_selected_plots(self, selection):
        """
        Called when the 'Remove Selected Plots' button is clicked.

        ``model_index[0].row`` of each item of *selection* is taken as a
        position in ``plot_list_view``'s iteration order.
        """
        names = list(self.plot_list_view.keys())
        to_remove = [names[model_index[0].row] for model_index in selection]
        for name in to_remove:
            # pop with default: the same row may be selected twice
            self.plot_list_view.pop(name, None)

        # Copy first: removing while iterating the live list skips items.
        for plot in list(self.container.components):
            self.container.remove(plot)
        for plot in self.plot_list_view.values():
            self.container.add(plot)
        self.container.request_redraw()

    def edit_selection(self, show_grid, plot_visible, plot_type_disc):
        """
        Called to start editing the selected plot. Should accompany the
        'Edit Plot' dialog.

        Rebuilds ``self.plot`` from its data with the current colour/marker
        settings, the requested type and visibility, and without grid
        underlays when *show_grid* is false.
        """
        old_plot = self.plot
        self.container.remove(old_plot)
        self.plot_type_disc = plot_type_disc
        plot_type = "scatter" if self.plot_type_disc else "line"
        plot = Plot(old_plot.data)
        plot.plot(("x", "y"), color=self.color, type=plot_type,
                  marker=self.marker, marker_size=self.marker_size)
        plot.visible = plot_visible
        if not show_grid:
            grids = [u for u in plot.underlays if isinstance(u, PlotGrid)]
            for underlay in grids:
                plot.underlays.remove(underlay)
        # Point the registry at the rebuilt plot; otherwise a later
        # remove_selected_plots() would re-add the stale one.
        for name, registered in list(self.plot_list_view.items()):
            if registered is old_plot:
                self.plot_list_view[name] = plot
        self.plot = plot
        self.container.add(self.plot)
        self.container.request_redraw()
        self.selection_handler.flush()

    def _toggle_broadcast_tool(self, attr, enabled, make_tool):
        """
        Install (enabled) or remove the BroadcasterTool remembered in *attr*.

        Each checkbox tracks its own broadcaster, so un-checking one tool no
        longer removes the broadcasters of the others.  *make_tool* builds
        the per-plot tool.
        """
        old = getattr(self, attr, None)
        if old is not None and old in self.container.tools:
            self.container.tools.remove(old)
        broadcaster = None
        if enabled:
            broadcaster = BroadcasterTool()
            for plot in self.container.components:
                broadcaster.tools.append(make_tool(plot))
            self.container.tools.append(broadcaster)
        setattr(self, attr, broadcaster)

    def _add_pan_tool_changed(self):
        """
        Method called when the Pan Tool checkbox is checked or unchecked.
        Adds the Pan Tool to the plot container if it isn't there and vice
        versa.
        """
        self._toggle_broadcast_tool("_pan_broadcaster", self.add_pan_tool, PanTool)

    def _add_zoom_tool_changed(self):
        """
        Method called when the Zoom Tool checkbox is checked or unchecked.
        Adds the Zoom Tool to the plot container if it isn't there and vice
        versa.
        """
        self._toggle_broadcast_tool("_zoom_broadcaster", self.add_zoom_tool, ZoomTool)

    def _add_dragzoom_changed(self):
        """
        Method called when the drag-zoom checkbox is checked or unchecked.
        Adds a box-selection zoom tool to the plot container if it isn't
        there and vice versa.
        """
        def make_dragzoom(plot):
            return BetterSelectingZoom(
                plot,
                always_on=True,
                tool_mode="box",
                drag_button="left",
                color="lightskyblue",
                alpha=0.4,
                border_color="dodgerblue",
            )

        self._toggle_broadcast_tool("_dragzoom_broadcaster", self.add_dragzoom,
                                    make_dragzoom)

    def _show_grid_changed(self):
        """
        Called when the Show grid checkbox is checked or unchecked.
        Adds a grid if one is not present and removes if present.
        """
        self._set_grid_visible(self.show_grid)

    def reassign_current_plot(self):
        """
        Reassigns the currently selected plot.
        """
        self.selection_handler.create_selection()
        plot_index = self.selection_handler.selected_indices[0][0]
        # dict views are not indexable on Python 3
        plot_name = list(self.plot_list_view.keys())[plot_index]
        self.plot = self.plot_list_view[plot_name]
        self.selection_handler.flush()
class StackedPlot(ChacoPlot):
    """
    Chaco plot that draws one line per dataset, each shifted vertically so the
    traces are stacked, with shared index/value mappers and a box zoom.
    """

    # Vertical step between consecutive traces, as a fraction of the smallest
    # (scaled) trace span
    offset = Range(0.0, 1.0, 0.015)
    # Scale applied to each trace and to the top of the value axis
    value_range = Range(0.01, 1.05, 1.00)
    # Reverse the stacking order of the traces
    flip_order = Bool(False)

    def _get_traits_group(self):
        return VGroup(
            HGroup(Item("flip_order"), Item("offset"), Item("value_range")),
            UItem("component", editor=ComponentEditor()),
        )

    def __init__(self):
        super(StackedPlot, self).__init__()
        self.container = OverlayPlotContainer(
            bgcolor="white", use_backbuffer=True, border_visible=True,
            padding=50, padding_left=110, fill_padding=True
        )
        self.data = ArrayPlotData()
        self.chaco_plot = None
        self.value_mapper = None
        self.index_mapper = None
        self.x_axis = PlotAxis(
            component=self.container,
            orientation="bottom",
            title=u"Angle (2\u0398)",
            title_font=settings.axis_title_font,
            tick_label_font=settings.tick_font,
        )
        y_axis_title = "Normalized intensity (%s)" % get_value_scale_label("linear")
        self.y_axis = PlotAxis(
            component=self.container,
            orientation="left",
            title=y_axis_title,
            title_font=settings.axis_title_font,
            tick_label_font=settings.tick_font,
        )
        self.container.overlays.extend([self.x_axis, self.y_axis])
        self.container.tools.append(TraitsTool(self.container, classes=[LinePlot, PlotAxis]))
        self.colors = []
        self.last_flip_order = self.flip_order

    @on_trait_change("offset, value_range, flip_order")
    def _replot_data(self):
        # The controls can change before anything has been plotted.
        if getattr(self, "data_x", None) is None:
            return
        self._plot(self.data_x, None, self.data_z, self.scale)
        self.container.request_redraw()

    def _prepare_data(self, datasets):
        """Stack *datasets* and return (x, None, z) rows ready for _plot()."""
        interpolate = True
        stack = stack_datasets(datasets)
        if interpolate:
            (x, z) = interpolate_datasets(stack, points=4800)
            x = array([x] * len(datasets))
        else:
            x, z = map(np.transpose, np.transpose(stack))
        return x, None, z

    def _plot(self, x, y, z, scale):
        """(Re)build one offset line plot per row of *x*/*z*; *y* is unused."""
        self.data_x, self.data_z, self.scale = x, z, scale
        if self.container.components:
            # Keep the current line colours across the rebuild.  A list is
            # required: a Python 3 ``map`` object has no reverse() or len().
            self.colors = [plot.color for plot in self.container.components]
            if self.last_flip_order != self.flip_order:
                self.colors.reverse()
            self.container.remove(*self.container.components)
        # Use a custom renderer so plot lines are clickable
        self.chaco_plot = Plot(self.data, renderer_map={"line": ClickableLinePlot})
        self.chaco_plot.bgcolor = "white"
        self.value_mapper = None
        self.index_mapper = None

        if len(self.data_x) == len(self.colors):
            colors = self.colors[:]
        else:
            colors = ["black"] * len(self.data_x)

        if self.flip_order:
            z = z[::-1]

        spacing = (z.max(axis=1) - z.min(axis=1)).min() * self.value_range
        offset = spacing * self.offset
        for i, (x_row, z_row) in enumerate(zip(x, z)):
            self.data.set_data("data_x_" + str(i), x_row)
            self.data.set_data("data_y_offset_" + str(i), z_row * self.value_range + offset * i)
            plots = self.chaco_plot.plot(("data_x_" + str(i), "data_y_offset_" + str(i)),
                                         color=colors[i], type="line")
            plot = plots[0]
            self.container.add(plot)
            # Required for double-clicking plots
            plot.index.sort_order = "ascending"
            plot.value.sort_order = "ascending"

            # All traces share the first trace's mappers.
            if self.value_mapper is None:
                self.index_mapper = plot.index_mapper
                self.value_mapper = plot.value_mapper
            else:
                plot.value_mapper = self.value_mapper
                self.value_mapper.range.add(plot.value)
                plot.index_mapper = self.index_mapper
                self.index_mapper.range.add(plot.index)
        value_axis_range = self.value_mapper.range
        value_axis_range.high = ((value_axis_range.high - value_axis_range.low)
                                 * self.value_range + value_axis_range.low)
        self.x_axis.mapper = self.index_mapper
        self.y_axis.mapper = self.value_mapper
        self.y_axis.title = "Normalized intensity (%s)" % get_value_scale_label(scale)
        self.zoom_tool = ClickUndoZoomTool(
            plot,
            tool_mode="box",
            always_on=True,
            pointer="cross",
            drag_button=settings.zoom_button,
            undo_button=settings.undo_button,
        )
        plot.overlays.append(self.zoom_tool)
        self.last_flip_order = self.flip_order
        return self.container

    def _reset_view(self):
        """Undo all zoom steps (no-op before the first plot exists)."""
        zoom_tool = getattr(self, "zoom_tool", None)
        if zoom_tool is not None:
            zoom_tool.revert_history_all()
class KMeansPlotHandler(HasTraits):
    """
    Class for plotting the k-means clusters.
    """

    # the data to cluster (rows are samples, columns are features)
    data = Array

    # the dataset created after preprocessing
    dataset = Array

    # the sklearn.cluster.KMeans object
    kmeans = Instance(KMeans)

    # Number of clusters
    n_clusters = Int

    # Maximum iterations for the clustering algorithm (applied when > 0)
    max_iter = Int

    # Container for the cluster plots
    container = Instance(OverlayPlotContainer)

    # the columns from the dataset to omit when performing clustering, given
    # as column positions (strings or ints); other entries are ignored
    to_omit = List

    def __init__(self, **traits):
        super(KMeansPlotHandler, self).__init__(**traits)
        self.kmeans = KMeans()
        self.container = OverlayPlotContainer()

    def create_dataset(self):
        """
        Creates a numpy array from the current selection to pass to the
        sklearn.cluster.kmeans object.

        Every column of ``data`` whose position is not listed in ``to_omit``
        is kept, in order; with nothing to omit the dataset is all of
        ``data``.
        """
        if self.data.ndim < 2:
            # Nothing column-shaped to select from; use the data as-is.
            self.dataset = self.data
            return
        omitted = set(int(elem) for elem in self.to_omit if str(elem).isdigit())
        keep = [col for col in range(self.data.shape[1]) if col not in omitted]
        # One fancy-indexing copy instead of growing the array with hstack.
        self.dataset = self.data[:, keep]

    def plot_clusters(self):
        """
        Plots the clusters after calling the .fit method of the sklearn kmeans
        estimator.

        The dataset is reduced to two dimensions with a whitened PCA for
        display.  The cluster centres are projected with the same fitted PCA,
        so that they land in the same coordinates as the points.
        """
        self.kmeans.n_clusters = self.n_clusters
        if self.max_iter > 0:
            self.kmeans.max_iter = self.max_iter
        self.kmeans.fit(self.dataset)

        pca = PCA(n_components=2, whiten=True)
        dataset_red = pca.fit_transform(self.dataset)
        cluster_centers = pca.transform(self.kmeans.cluster_centers_)

        # Copy first: removing while iterating the live list skips items.
        for component in list(self.container.components):
            self.container.remove(component)

        labels = self.kmeans.labels_
        for i in range(self.n_clusters):
            current_data = dataset_red[labels == i]
            plotdata = ArrayPlotData(x=current_data[:, 0], y=current_data[:, 1])
            plot = Plot(plotdata)
            # Cycle the palette so more clusters than colours still plot.
            color = tuple(COLOR_PALETTE[i % len(COLOR_PALETTE)])
            plot.plot(("x", "y"), type='scatter', color=color)
            self.container.add(plot)

        plotdata_cent = ArrayPlotData(x=cluster_centers[:, 0], y=cluster_centers[:, 1])
        plot_cent = Plot(plotdata_cent)
        plot_cent.plot(("x", "y"), type='scatter', marker='cross', marker_size=8)
        self.container.add(plot_cent)
        self.container.request_redraw()
class XYPlotHandler(HasTraits):
    """ Class for handling XY plots. """

    # Whether the data is a pandas dataframe
    AS_PANDAS_DATAFRAME = Bool

    # The container for all current plots. Gets updated every time a plot is
    # added.
    container = Instance(OverlayPlotContainer)

    # Data source of the most recently added plot.
    plotdata = Instance(ArrayPlotData)

    # The current Plot object.
    plot = Instance(Plot)

    # ColorTrait, mainly required for the TraitsUI Item view.
    color = ColorTrait("blue")

    # Marker trait for the view
    marker = marker_trait

    # Marker size trait
    marker_size = Int(4)

    # An instance of SelectionHandler for adding plots from the current
    # selection.
    selection_handler = Instance(SelectionHandler)

    # Bool traits for checking the type of the plot (discrete / continuous)
    plot_type_disc = Bool
    plot_type_cont = Bool

    # The data from which to draw the plots, same as the table attribute of
    # CsvModel
    table = Array

    # The pandas data frame if AS_PANDAS_DATAFRAME
    data_frame = Instance(DataFrame)

    # Contains the grid underlays of all the current plots
    grid_underlays = List

    # Maps plot names to Plot objects; used for the plot list and the legend.
    plot_list_view = Dict

    # TraitsUI view for plot properties, yet to find an enaml implementation
    view = View(Item('color'), Item('marker'), Item('marker_size'))

    # Traits that define whether tools are present.
    add_pan_tool = Bool
    add_zoom_tool = Bool
    add_dragzoom = Bool

    # Whether grids are visible
    show_grid = Bool

    # The BroadcasterTool installed on the container for each tool kind
    # ('pan', 'zoom', 'dragzoom'), so each kind can be removed on its own.
    _broadcasters = Dict

    def __init__(self):
        super(XYPlotHandler, self).__init__()
        self.selection_handler = SelectionHandler()
        self.container = OverlayPlotContainer()
        # Legacy attribute; not read anywhere in this class.
        self.underlays = []
        self.add_pan_tool = False
        self.add_zoom_tool = False
        self.add_dragzoom = False
        self.show_grid = False

    def _unique_plot_name(self):
        """ Return the first 'plotN' name (N >= number of plots) not yet used. """
        n = len(self.plot_list_view)
        while 'plot' + str(n) in self.plot_list_view:
            n += 1
        return 'plot' + str(n)

    def add_xyplot_selection(self, plot_name):
        """ Add a plot built from the current two-column selection.

        Called when the 'Add plot from selection' button is clicked. The
        first selected column supplies x and the second supplies y. The plot
        is a scatter plot if `plot_type_disc` is set, otherwise a line plot.
        It is stored in `plot_list_view` under `plot_name`, or under a fresh
        'plotN' name when `plot_name` is empty. Its grid underlays are
        collected into `grid_underlays` and detached from the plot.
        """
        self.selection_handler.create_selection()
        if self.selection_handler.xyplot_check():
            first_index = self.selection_handler.selected_indices[0]
            second_index = self.selection_handler.selected_indices[1]
            if self.AS_PANDAS_DATAFRAME:
                x_column = self.data_frame.columns[first_index[1]]
                y_column = self.data_frame.columns[second_index[1]]
                x = np.array(self.data_frame[x_column])
                y = np.array(self.data_frame[y_column])
                self.plotdata = ArrayPlotData(x=x, y=y)
            else:
                self.plotdata = ArrayPlotData(x=self.table[:, first_index[1]],
                                              y=self.table[:, second_index[1]])

            plot = Plot(self.plotdata)
            plot_type = 'scatter' if self.plot_type_disc else 'line'
            plot.plot(("x", "y"), type=plot_type, color=self.color,
                      marker=self.marker, marker_size=self.marker_size)
            self.plot = plot

            # Remember this plot's grid so it can be toggled later, then
            # detach every known grid from it.
            for underlay in self.plot.underlays:
                if isinstance(underlay, PlotGrid):
                    if underlay not in self.grid_underlays:
                        self.grid_underlays.append(underlay)
            for underlay in self.grid_underlays:
                if underlay in self.plot.underlays:
                    self.plot.underlays.remove(underlay)

            if plot_name == '':
                # Auto-names must not overwrite an existing plot. This can
                # happen after removals, when len() no longer gives a free
                # index.
                plot_name = self._unique_plot_name()
            self.plot_list_view[plot_name] = self.plot
            self.container.add(self.plot)
        self.container.request_redraw()
        self.selection_handler.flush()

    def _set_grid_visible(self, visible):
        """ Attach (visible) or detach (not visible) every known grid on all plots. """
        for plot in self.container.components:
            for underlay in self.grid_underlays:
                if visible:
                    if underlay not in plot.underlays:
                        plot.underlays.append(underlay)
                elif underlay in plot.underlays:
                    plot.underlays.remove(underlay)
        self.container.request_redraw()

    def grid_toggle(self, checked):
        """ Called when the 'Show Grid' checkbox is toggled. """
        self._set_grid_visible(checked)

    def remove_selected_plots(self, selection):
        """ Remove the plots selected in the plot list view.

        Called when the 'Remove Selected Plots' button is clicked. For each
        item of `selection`, `item[0].row` is mapped to a plot name by its
        position in `plot_list_view`'s key order. The container is then
        rebuilt from the plots that remain.
        """
        names = list(self.plot_list_view.keys())
        remove_names = [names[model_index[0].row] for model_index in selection]
        for name in remove_names:
            # pop with a default so a row selected twice does not raise.
            self.plot_list_view.pop(name, None)

        # Iterate over a snapshot: removing from the live list while
        # iterating it would skip every other component.
        for plot in list(self.container.components):
            self.container.remove(plot)
        for plot in self.plot_list_view.values():
            self.container.add(plot)
        self.container.request_redraw()

    def edit_selection(self, show_grid, plot_visible, plot_type_disc):
        """ Re-create the current plot with the current style settings.

        Should accompany the 'Edit Plot' dialog. The current plot is
        replaced by a new Plot on the same data, using `color`, `marker`,
        `marker_size` and the scatter/line choice from `plot_type_disc`. Its
        visibility is set from `plot_visible`, and its own grids are stripped
        unless `show_grid` is true. Any `plot_list_view` entry that referred
        to the old plot now refers to the new one.
        """
        old_plot = self.plot
        self.container.remove(old_plot)
        self.plot_type_disc = plot_type_disc
        plot_type = 'scatter' if self.plot_type_disc else 'line'

        plot = Plot(old_plot.data)
        plot.plot(("x", "y"), color=self.color, type=plot_type,
                  marker=self.marker, marker_size=self.marker_size)
        plot.visible = plot_visible
        if not show_grid:
            grids = [u for u in plot.underlays if isinstance(u, PlotGrid)]
            for underlay in grids:
                plot.underlays.remove(underlay)

        # Keep the name -> plot mapping in sync. Otherwise a later
        # remove_selected_plots() would re-add the stale plot.
        for name, listed in list(self.plot_list_view.items()):
            if listed is old_plot:
                self.plot_list_view[name] = plot

        self.plot = plot
        self.container.add(self.plot)
        self.container.request_redraw()
        self.selection_handler.flush()

    def _toggle_broadcast_tool(self, key, enabled, make_tool):
        """ Install or remove the broadcaster that owns one kind of tool.

        `key` names the tool kind; `make_tool(plot)` builds the tool for one
        plot. Each kind has its own BroadcasterTool, so toggling one kind
        leaves the others in place.
        """
        old = self._broadcasters.pop(key, None)
        if old is not None and old in self.container.tools:
            self.container.tools.remove(old)
        if enabled and self.container.components:
            broadcaster = BroadcasterTool()
            for plot in self.container.components:
                broadcaster.tools.append(make_tool(plot))
            # Append once, not once per plot.
            self.container.tools.append(broadcaster)
            self._broadcasters[key] = broadcaster

    def _add_pan_tool_changed(self):
        """ Called when the Pan Tool checkbox is checked or unchecked.

        Adds a Pan Tool for every current plot, or removes the pan tools.
        """
        self._toggle_broadcast_tool('pan', self.add_pan_tool, PanTool)

    def _add_zoom_tool_changed(self):
        """ Called when the Zoom Tool checkbox is checked or unchecked.

        Adds a Zoom Tool for every current plot, or removes the zoom tools.
        """
        self._toggle_broadcast_tool('zoom', self.add_zoom_tool, ZoomTool)

    def _add_dragzoom_changed(self):
        """ Called when the drag-zoom checkbox is checked or unchecked.

        Adds an always-on left-drag box zoom for every current plot, or
        removes the drag-zoom tools.
        """
        def make_dragzoom(plot):
            return BetterSelectingZoom(plot, always_on=True, tool_mode='box',
                                       drag_button='left',
                                       color='lightskyblue', alpha=0.4,
                                       border_color='dodgerblue')
        self._toggle_broadcast_tool('dragzoom', self.add_dragzoom,
                                    make_dragzoom)

    def _show_grid_changed(self):
        """ Called when the Show grid checkbox is checked or unchecked.

        Attaches the known grids to every plot, or detaches them.
        """
        self._set_grid_visible(self.show_grid)

    def reassign_current_plot(self):
        """ Make the plot selected in the plot list the current plot. """
        self.selection_handler.create_selection()
        plot_index = self.selection_handler.selected_indices[0][0]
        plot_name = list(self.plot_list_view.keys())[plot_index]
        self.plot = self.plot_list_view[plot_name]
        self.selection_handler.flush()
class StackedPlot(ChacoPlot):
    """ Chaco plot that draws several traces stacked with a vertical offset. """

    # Vertical offset between consecutive traces, as a fraction of the
    # smallest trace span (after value_range scaling).
    offset = Range(0.0, 1.0, 0.015)
    # Factor applied to every trace and to the upper end of the value range.
    value_range = Range(0.01, 1.05, 1.00)
    # Draw the traces in reverse order.
    flip_order = Bool(False)

    def _get_traits_group(self):
        return VGroup(
            HGroup(
                Item('flip_order'),
                Item('offset'),
                Item('value_range'),
            ),
            UItem('component', editor=ComponentEditor()),
        )

    def __init__(self):
        super(StackedPlot, self).__init__()
        self.container = OverlayPlotContainer(bgcolor='white',
                                              use_backbuffer=True,
                                              border_visible=True,
                                              padding=50,
                                              padding_left=110,
                                              fill_padding=True)
        self.data = ArrayPlotData()
        self.chaco_plot = None
        self.value_mapper = None
        self.index_mapper = None
        self.x_axis = MyPlotAxis(component=self.container,
                                 orientation='bottom',
                                 title=u'Angle (2\u0398)',
                                 title_font=settings.axis_title_font,
                                 tick_label_font=settings.tick_font)
        y_axis_title = 'Normalized intensity (%s)' % get_value_scale_label('linear')
        self.y_axis = MyPlotAxis(component=self.container,
                                 orientation='left',
                                 title=y_axis_title,
                                 title_font=settings.axis_title_font,
                                 tick_label_font=settings.tick_font)
        self.container.overlays.extend([self.x_axis, self.y_axis])
        self.container.tools.append(
            TraitsTool(self.container, classes=[LinePlot, MyPlotAxis]))
        self.colors = []
        self.last_flip_order = self.flip_order

    @on_trait_change('offset, value_range, flip_order')
    def _replot_data(self):
        # Nothing to redraw until _plot() has stored data at least once.
        if getattr(self, 'data_x', None) is None:
            return
        self._plot(self.data_x, None, self.data_z, self.scale)
        self.container.request_redraw()

    def _prepare_data(self, stack):
        """ Split a stacked array into (x, None, z) for _plot().

        `stack` is indexed as stack[:, :, k]: column 0 is used as x and
        column 2 as z (presumably shape (n_traces, n_points, >=3) -- confirm
        against stack_datasets).
        """
        x = stack[:, :, 0]
        z = stack[:, :, 2]
        return x, None, z

    def _plot(self, x, y, z, scale):
        """ Draw row i of `z` against row i of `x`, shifted up by i * offset.

        `y` is unused and kept for interface compatibility. `scale` only
        selects the y-axis title. Existing line colours are reused when the
        number of traces is unchanged. Returns the plot container.
        """
        self.data_x, self.data_z, self.scale = x, z, scale
        if self.container.components:
            self.colors = [plot.color for plot in self.container.components]
            if self.last_flip_order != self.flip_order:
                self.colors.reverse()
            self.container.remove(*self.container.components)

        # Use a custom renderer so plot lines are clickable
        self.chaco_plot = Plot(self.data, renderer_map={
            'line': ClickableLinePlot
        })
        self.chaco_plot.bgcolor = 'white'
        self.value_mapper = None
        self.index_mapper = None

        if len(self.data_x) == len(self.colors):
            colors = self.colors[:]
        else:
            colors = ['black'] * len(self.data_x)

        if len(z) == 0:
            # No traces: leave the container empty instead of failing below.
            self.last_flip_order = self.flip_order
            return self.container

        if self.flip_order:
            # Reverse x together with z so each trace keeps its own x values.
            x = x[::-1]
            z = z[::-1]

        spacing = (z.max(axis=1) - z.min(axis=1)).min() * self.value_range
        offset = spacing * self.offset
        for i, (x_row, z_row) in enumerate(zip(x, z)):
            x_name = 'data_x_' + str(i)
            y_name = 'data_y_offset_' + str(i)
            self.data.set_data(x_name, x_row)
            self.data.set_data(y_name, z_row * self.value_range + offset * i)
            plot = self.chaco_plot.plot((x_name, y_name), color=colors[i],
                                        type='line')[0]
            self.container.add(plot)
            # Required for double-clicking plots
            plot.index.sort_order = 'ascending'
            plot.value.sort_order = 'ascending'

            # All traces share the first trace's mappers.
            if self.value_mapper is None:
                self.index_mapper = plot.index_mapper
                self.value_mapper = plot.value_mapper
            else:
                plot.value_mapper = self.value_mapper
                self.value_mapper.range.add(plot.value)
                plot.index_mapper = self.index_mapper
                self.index_mapper.range.add(plot.index)

        value_range = self.value_mapper.range
        value_range.high = ((value_range.high - value_range.low)
                            * self.value_range + value_range.low)

        self.x_axis.mapper = self.index_mapper
        self.y_axis.mapper = self.value_mapper
        self.y_axis.title = 'Normalized intensity (%s)' % \
            get_value_scale_label(scale)

        self.zoom_tool = ClickUndoZoomTool(
            plot, tool_mode="box", always_on=True, pointer="cross",
            drag_button=settings.zoom_button,
            undo_button=settings.undo_button,
        )
        plot.overlays.append(self.zoom_tool)

        self.last_flip_order = self.flip_order
        return self.container

    def _reset_view(self):
        # The zoom tool only exists once something has been plotted.
        zoom_tool = getattr(self, 'zoom_tool', None)
        if zoom_tool is not None:
            zoom_tool.revert_history_all()
class TemplatePicker(HasTraits):
    """ Interactive picker for an image template and its correlation peaks.

    A template window is cut from an image at (left, top) with side
    `tmp_size`. It is cross-correlated (via cv_funcs.xcorr) against every
    image of the signal, and peaks are located in each correlation. The
    colorbar selection or the threshold traits restrict which peaks count.
    The crop_cells* methods cut template-sized regions at the retained
    peaks out of the images.
    """
    template = Array
    CC = Array
    peaks = List
    zero = Int(0)
    tmp_size = Range(low=2, high=512, value=64, cols=4)
    max_pos_x = Property(depends_on=['tmp_size'])
    max_pos_y = Property(depends_on=['tmp_size'])
    top = Range(low='zero', high='max_pos_y', value=20, cols=4)
    left = Range(low='zero', high='max_pos_x', value=20, cols=4)
    is_square = Bool
    img_plot = Instance(Plot)
    tmp_plot = Instance(Plot)
    findpeaks = Button
    next_img = Button
    prev_img = Button
    peak_width = Range(low=2, high=200, value=10)
    tab_selected = Event
    ShowCC = Bool
    img_container = Instance(Component)
    container = Instance(Component)
    colorbar = Instance(Component)
    numpeaks_total = Int(0, cols=5)
    numpeaks_img = Int(0, cols=5)
    OK_custom = OK_custom_handler
    cbar_selection = Instance(RangeSelection)
    cbar_selected = Event
    thresh = Trait(None, None, List, Tuple, Array)
    thresh_upper = Float(1.0)
    thresh_lower = Float(0.0)
    numfiles = Int(1)
    img_idx = Int(0)
    tmp_img_idx = Int(0)
    csr = Instance(BaseCursorTool)

    traits_view = View(
        Group(
            Group(
                Item("img_container", editor=ComponentEditor(),
                     show_label=False),
                HGroup(
                    Item("ShowCC", editor=BooleanEditor(),
                         label="Show cross correlation image"),
                    Spring(),
                    Item("prev_img", editor=ButtonEditor(label="<"),
                         show_label=False, enabled_when='numfiles > 1'),
                    Item("next_img", editor=ButtonEditor(label=">"),
                         show_label=False, enabled_when='numfiles > 1'),
                ),
                label="Original image", show_border=True,
                trait_modified="tab_selected",
                orientation='vertical',
            ),
            VGroup(
                Group(
                    HGroup(
                        Item("left", label="Left coordinate", style="custom"),
                        Spring(),
                        Item("top", label="Top coordinate", style="custom"),
                    ),
                    Item("tmp_size", label="Template size", style="custom"),
                    Item("tmp_plot",
                         editor=ComponentEditor(height=256, width=256),
                         show_label=False, resizable=True),
                    label="Template", show_border=True,
                ),
                Group(
                    HGroup(
                        Item("peak_width", label="Peak width", style="custom"),
                        Spring(),
                        Item("findpeaks",
                             editor=ButtonEditor(label="Find Peaks"),
                             show_label=False),
                    ),
                    HGroup(
                        Item("thresh_lower", label="Threshold Lower Value",
                             editor=TextEditor(evaluate=float,
                                               format_str='%1.4f')),
                        Spring(),
                        Item("thresh_upper", label="Threshold Upper Value",
                             editor=TextEditor(evaluate=float,
                                               format_str='%1.4f')),
                    ),
                    HGroup(
                        Item("numpeaks_img",
                             label="Number of Cells selected (this image)",
                             style='readonly'),
                        Spring(),
                        Item("numpeaks_total", label="Total",
                             style='readonly'),
                    ),
                    label="Peak parameters", show_border=True,
                ),
            ),
            orientation='horizontal',
        ),
        buttons=[Action(name='OK', enabled_when='numpeaks_total > 0'),
                 CancelButton],
        title="Template Picker",
        handler=OK_custom,
        kind='livemodal',
        key_bindings=key_bindings,
        width=940, height=530, resizable=True)

    def __init__(self, signal_instance):
        super(TemplatePicker, self).__init__()
        # Only probe for OpenCV here (the legacy `cv` module, then cv2's
        # legacy binding); the correlation itself goes through cv_funcs.
        try:
            import cv  # noqa: F401
        except ImportError:
            try:
                import cv2.cv as cv  # noqa: F401
            except ImportError:
                print("OpenCV unavailable. Can't do cross correlation without it. Aborting.")
                return None
        self.OK_custom = OK_custom_handler()
        self.sig = signal_instance
        if not hasattr(self.sig.mapped_parameters, "original_files"):
            self.titles = [os.path.splitext(self.sig.mapped_parameters.title)[0]]
        else:
            original_files = self.sig.mapped_parameters.original_files
            self.numfiles = len(original_files.keys())
            # list() so titles can be indexed on Python 3 as well.
            self.titles = list(original_files.keys())
        tmp_plot_data = ArrayPlotData(imagedata=self._template(self.img_idx))
        tmp_plot = Plot(tmp_plot_data, default_origin="top left")
        tmp_plot.img_plot("imagedata", colormap=jet)
        tmp_plot.aspect_ratio = 1.0
        self.tmp_plot = tmp_plot
        self.tmp_plotdata = tmp_plot_data
        self.img_plotdata = ArrayPlotData(imagedata=self.sig.data[self.img_idx, :, :])
        self.img_container = self._image_plot_container()
        self.crop_sig = None

    def _template(self, idx):
        """ Return the template window of image `idx` at (top, left).

        The window has side `tmp_size`; numpy slicing clips it at the
        image edge.
        """
        return self.sig.data[idx,
                             self.top:self.top + self.tmp_size,
                             self.left:self.left + self.tmp_size]

    def _image_title(self):
        """ Title of the main image plot, e.g. '1 of 3: name'. """
        return "%s of %s: " % (self.img_idx + 1, self.numfiles) + self.titles[self.img_idx]

    def render_image(self):
        plot = Plot(self.img_plotdata, default_origin="top left")
        img = plot.img_plot("imagedata", colormap=gray)[0]
        plot.title = self._image_title()
        plot.aspect_ratio = float(self.sig.data.shape[2]) / float(self.sig.data.shape[1])
        csr = CursorTool(img, drag_button='left', color='white',
                         line_width=2.0)
        self.csr = csr
        csr.current_position = self.left, self.top
        img.overlays.append(csr)
        # Right-drag pans; a box zoom is attached as an overlay.
        plot.tools.append(PanTool(plot, drag_button="right"))
        zoom = ZoomTool(plot, tool_mode="box", always_on=False,
                        aspect_ratio=plot.aspect_ratio)
        plot.overlays.append(zoom)
        self.img_plot = plot
        return plot

    def render_scatplot(self):
        peakdata = ArrayPlotData()
        peakdata.set_data("index", self.peaks[self.img_idx][:, 0])
        peakdata.set_data("value", self.peaks[self.img_idx][:, 1])
        peakdata.set_data("color", self.peaks[self.img_idx][:, 2])
        scatplot = Plot(peakdata, aspect_ratio=self.img_plot.aspect_ratio,
                        default_origin="top left")
        scatplot.plot(("index", "value", "color"),
                      type="cmap_scatter",
                      name="my_plot",
                      color_mapper=jet(DataRange1D(low=0.0, high=1.0)),
                      marker="circle",
                      fill_alpha=0.5,
                      marker_size=6,
                      )
        scatplot.x_grid.visible = False
        scatplot.y_grid.visible = False
        scatplot.range2d = self.img_plot.range2d
        self.scatplot = scatplot
        self.peakdata = peakdata
        return scatplot

    def _image_plot_container(self):
        plot = self.render_image()

        # Create a container to position the plot and the colorbar side-by-side
        self.container = OverlayPlotContainer()
        self.container.add(plot)
        self.img_container = HPlotContainer(use_backbuffer=False)
        self.img_container.add(self.container)
        self.img_container.bgcolor = "white"

        if self.numpeaks_img > 0:
            scatplot = self.render_scatplot()
            self.container.add(scatplot)
            colorbar = self.draw_colorbar()
            self.img_container.add(colorbar)
        return self.img_container

    def draw_colorbar(self):
        scatplot = self.scatplot
        cmap_renderer = scatplot.plots["my_plot"][0]
        selection = ColormappedSelectionOverlay(cmap_renderer, fade_alpha=0.35,
                                                selection_type="range")
        cmap_renderer.overlays.append(selection)
        if self.thresh is not None:
            cmap_renderer.color_data.metadata['selections'] = self.thresh
            cmap_renderer.color_data.metadata_changed = {'selections': self.thresh}
        # Create the colorbar, handing in the appropriate range and the
        # scatter renderer's colormap.
        colorbar = ColorBar(index_mapper=LinearMapper(range=DataRange1D(low=0.0, high=1.0)),
                            color_mapper=cmap_renderer.color_mapper,
                            orientation='v',
                            resizable='v',
                            width=30,
                            padding=20)
        colorbar_selection = RangeSelection(component=colorbar)
        colorbar.tools.append(colorbar_selection)
        colorbar.overlays.append(RangeSelectionOverlay(component=colorbar,
                                                       border_color="white",
                                                       alpha=0.8,
                                                       fill_color="lightgray",
                                                       metadata_name='selections'))
        self.cbar_selection = colorbar_selection
        self.cmap_renderer = cmap_renderer
        colorbar.plot = cmap_renderer
        colorbar.padding_top = scatplot.padding_top
        colorbar.padding_bottom = scatplot.padding_bottom
        self.colorbar = colorbar
        return colorbar

    @on_trait_change('ShowCC')
    def toggle_cc_view(self):
        if self.ShowCC:
            self.update_CC()
            grid_data_source = self.img_plot.range2d.sources[0]
            grid_data_source.set_data(np.arange(self.CC.shape[1]),
                                      np.arange(self.CC.shape[0]))
        else:
            self.img_plotdata.set_data("imagedata", self.sig.data[self.img_idx, :, :])
        self.redraw_plots()

    @on_trait_change("img_idx")
    def update_img_depth(self):
        if self.ShowCC:
            self.update_CC()
        else:
            self.img_plotdata.set_data("imagedata", self.sig.data[self.img_idx, :, :])
        self.img_plot.title = self._image_title()
        self.redraw_plots()

    def _get_max_pos_x(self):
        max_pos_x = self.sig.data.shape[-1] - self.tmp_size - 1
        if max_pos_x > 0:
            return max_pos_x
        else:
            return None

    def _get_max_pos_y(self):
        max_pos_y = self.sig.data.shape[-2] - self.tmp_size - 1
        if max_pos_y > 0:
            return max_pos_y
        else:
            return None

    @on_trait_change('next_img')
    def increase_img_idx(self, info):
        # Wrap around to the first image.
        if self.img_idx == (self.numfiles - 1):
            self.img_idx = 0
        else:
            self.img_idx += 1

    @on_trait_change('prev_img')
    def decrease_img_idx(self, info):
        # Wrap around to the last image.
        if self.img_idx == 0:
            self.img_idx = self.numfiles - 1
        else:
            self.img_idx -= 1

    @on_trait_change('left, top')
    def update_csr_position(self):
        if self.left > 0:
            self.csr.current_position = self.left, self.top

    @on_trait_change('csr:current_position')
    def update_top_left(self):
        pos_x, pos_y = self.csr.current_position[0], self.csr.current_position[1]
        if pos_x > 0 or pos_y > 0:
            if pos_x > self.max_pos_x:
                if pos_y < self.max_pos_y:
                    self.top = pos_y
                else:
                    self.csr.current_position = self.max_pos_x, self.max_pos_y
            elif pos_y > self.max_pos_y:
                self.left, self.top = pos_x, self.max_pos_y
            else:
                self.left, self.top = self.csr.current_position

    @on_trait_change('left, top, tmp_size')
    def update_tmp_plot(self):
        self.tmp_plotdata.set_data("imagedata", self._template(self.img_idx))
        grid_data_source = self.tmp_plot.range2d.sources[0]
        grid_data_source.set_data(np.arange(self.tmp_size), np.arange(self.tmp_size))
        self.tmp_img_idx = self.img_idx
        if self.numpeaks_total > 0:
            # A new template invalidates previously found peaks.
            print("clearing peaks")
            self.peaks = [np.array([[0, 0, -1]])]

    @on_trait_change('left, top, tmp_size')
    def update_CC(self):
        if self.ShowCC:
            self.CC = cv_funcs.xcorr(self._template(self.tmp_img_idx),
                                     self.sig.data[self.img_idx, :, :])
            self.img_plotdata.set_data("imagedata", self.CC)

    @on_trait_change('cbar_selection:selection')
    def update_thresh(self):
        # Fails benignly while no colorbar exists yet (no renderer or
        # selection) or while the selection is cleared (None).
        try:
            thresh = self.cbar_selection.selection
            self.thresh = thresh
            self.cmap_renderer.color_data.metadata['selections'] = thresh
            self.thresh_lower = thresh[0]
            self.thresh_upper = thresh[1]
            self.cmap_renderer.color_data.metadata_changed = {'selections': thresh}
            self.container.request_redraw()
            self.img_container.request_redraw()
        except (AttributeError, TypeError, IndexError):
            pass

    @on_trait_change('thresh_upper,thresh_lower')
    def manual_thresh_update(self):
        self.thresh = [self.thresh_lower, self.thresh_upper]
        cmap_renderer = getattr(self, 'cmap_renderer', None)
        if cmap_renderer is None:
            # No colorbar yet; the stored thresh is applied in draw_colorbar.
            return
        cmap_renderer.color_data.metadata['selections'] = self.thresh
        cmap_renderer.color_data.metadata_changed = {'selections': self.thresh}
        self.container.request_redraw()
        self.img_container.request_redraw()

    @on_trait_change('peaks,cbar_selection:selection,img_idx')
    def calc_numpeaks(self):
        try:
            thresh = self.cbar_selection.selection
            self.thresh = thresh
        except AttributeError:
            thresh = []
        if thresh is None or len(thresh) == 0:
            thresh = (0, 1)
        low, high = thresh[0], thresh[1]
        # masked_inside masks the values inside [low, high]; the mask sum
        # is therefore the number of peaks within the threshold.
        self.numpeaks_total = int(np.sum([
            np.sum(np.ma.masked_inside(pks[:, 2], low, high).mask)
            for pks in self.peaks]))
        try:
            self.numpeaks_img = int(np.sum(
                np.ma.masked_inside(self.peaks[self.img_idx][:, 2], low, high).mask))
        except (IndexError, TypeError):
            self.numpeaks_img = 0

    @on_trait_change('findpeaks')
    def locate_peaks(self):
        from hyperspy import peak_char as pc
        peaks = []
        progress = ProgressDialog(title="Peak finder progress",
                                  message="Finding peaks on %s images" % self.numfiles,
                                  max=self.numfiles, show_time=True,
                                  can_cancel=False)
        progress.open()
        template = self._template(self.tmp_img_idx)
        for idx in range(self.numfiles):
            self.CC = cv_funcs.xcorr(template, self.sig.data[idx, :, :])
            # peak finder needs peaks greater than 1. Multiply by 255 to scale them.
            pks = pc.two_dim_findpeaks(self.CC * 255, peak_width=self.peak_width,
                                       medfilt_radius=None)
            pks[:, 2] = pks[:, 2] / 255.
            peaks.append(pks)
            progress.update(idx + 1)
        self.peaks = peaks

    def mask_peaks(self, idx):
        """ Return peaks[idx] as a masked array.

        Entries of column 2 outside the current colorbar selection are
        masked. The selection defaults to (0, 1) when there is none.
        """
        thresh = getattr(self.cbar_selection, 'selection', None)
        if thresh is None or len(thresh) == 0:
            thresh = (0, 1)
        mpeaks = np.ma.asarray(self.peaks[idx])
        mpeaks[:, 2] = np.ma.masked_outside(mpeaks[:, 2], thresh[0], thresh[1])
        return mpeaks

    @on_trait_change("peaks")
    def redraw_plots(self):
        oldplot = self.img_plot
        self.container.remove(oldplot)
        newplot = self.render_image()
        self.container.add(newplot)
        self.img_plot = newplot

        # The scatter plot and colorbar only exist once peaks have been
        # drawn; remove each independently, and only if it is present.
        oldscat = getattr(self, 'scatplot', None)
        if oldscat is not None and oldscat in self.container.components:
            self.container.remove(oldscat)
        oldcolorbar = getattr(self, 'colorbar', None)
        if oldcolorbar is not None and oldcolorbar in self.img_container.components:
            self.img_container.remove(oldcolorbar)

        if self.numpeaks_img > 0:
            newscat = self.render_scatplot()
            self.container.add(newscat)
            self.scatplot = newscat
            colorbar = self.draw_colorbar()
            self.img_container.add(colorbar)
            self.colorbar = colorbar

        self.container.request_redraw()
        self.img_container.request_redraw()

    def crop_cells_stack(self):
        from hyperspy.signals.aggregate import AggregateCells
        if self.numfiles == 1:
            self.crop_sig = self.crop_cells()
            return
        else:
            crop_agg = []
            for idx in range(self.numfiles):
                peaks = np.ma.compress_rows(self.mask_peaks(idx))
                # Skip images with no peak inside the threshold.
                if peaks.shape[0] > 0:
                    crop_agg.append(self.crop_cells(idx))
            self.crop_sig = AggregateCells(*crop_agg)
            return

    def crop_cells(self, idx=0):
        print("cropping cells...")
        from hyperspy.signals.image import Image
        # filter the peaks that are outside the selected threshold
        peaks = np.ma.compress_rows(self.mask_peaks(idx))
        tmp_sz = self.tmp_size
        data = np.zeros((peaks.shape[0], tmp_sz, tmp_sz))
        if not hasattr(self.sig.mapped_parameters, "original_files"):
            parent = self.sig
        else:
            parent = self.sig.mapped_parameters.original_files[self.titles[idx]]
        pmp = parent.mapped_parameters
        positions = np.zeros((peaks.shape[0], 1),
                             dtype=[('filename', 'a256'), ('id', 'i4'),
                                    ('position', 'f4', (1, 2))])
        for i in range(peaks.shape[0]):
            # Peak coordinates are stored as floats; slice indices must be ints.
            col = int(peaks[i, 0])
            row = int(peaks[i, 1])
            # crop the cells from the given locations
            data[i, :, :] = self.sig.data[idx, row:row + tmp_sz, col:col + tmp_sz]
            positions[i] = (self.titles[idx], i, peaks[i, :2])
        crop_sig = Image({'data': data,
                          'mapped_parameters': {
                              'title': 'Cropped cells from %s' % self.titles[idx],
                              'record_by': 'image',
                              'locations': positions,
                              'original_files': {pmp.title: parent},
                          }
                          })
        return crop_sig
class StackedPlot(ChacoPlot):
    """ Chaco plot that draws several datasets stacked with a vertical offset. """

    # Vertical offset between consecutive traces, as a fraction of the
    # smallest trace span (after value_range scaling).
    offset = Range(0.0, 1.0, 0.015)
    # Factor applied to every trace and to the upper end of the value range.
    value_range = Range(0.01, 1.05, 1.00)
    # Draw the traces in reverse order.
    flip_order = Bool(False)

    def _get_traits_group(self):
        return VGroup(
            HGroup(
                Item('flip_order'),
                Item('offset'),
                Item('value_range'),
            ),
            UItem('component', editor=ComponentEditor()),
        )

    def __init__(self):
        super(StackedPlot, self).__init__()
        self.container = OverlayPlotContainer(bgcolor='white',
                                              use_backbuffer=True,
                                              border_visible=True,
                                              padding=50,
                                              padding_left=110,
                                              fill_padding=True)
        self.data = ArrayPlotData()
        self.chaco_plot = None
        self.value_mapper = None
        self.index_mapper = None
        self.x_axis = PlotAxis(component=self.container,
                               orientation='bottom',
                               title=u'Angle (2\u0398)',
                               title_font=settings.axis_title_font,
                               tick_label_font=settings.tick_font)
        y_axis_title = 'Normalized intensity (%s)' % get_value_scale_label('linear')
        self.y_axis = PlotAxis(component=self.container,
                               orientation='left',
                               title=y_axis_title,
                               title_font=settings.axis_title_font,
                               tick_label_font=settings.tick_font)
        self.container.overlays.extend([self.x_axis, self.y_axis])
        self.container.tools.append(
            TraitsTool(self.container, classes=[LinePlot, PlotAxis]))
        self.colors = []
        self.last_flip_order = self.flip_order

    @on_trait_change('offset, value_range, flip_order')
    def _replot_data(self):
        # Nothing to redraw until _plot() has stored data at least once.
        if getattr(self, 'data_x', None) is None:
            return
        self._plot(self.data_x, None, self.data_z, self.scale)
        self.container.request_redraw()

    def _prepare_data(self, datasets, interpolate=True, points=4800):
        """ Stack `datasets` and return (x, None, z) for _plot().

        With `interpolate` (the default), the stack is passed to
        interpolate_datasets with `points`. The x it returns is repeated
        once per dataset. Otherwise x and z come from the transposed stack
        directly.
        """
        stack = stack_datasets(datasets)
        if interpolate:
            x, z = interpolate_datasets(stack, points=points)
            x = np.array([x] * len(datasets))
        else:
            x, z = map(np.transpose, np.transpose(stack))
        return x, None, z

    def _plot(self, x, y, z, scale):
        """ Draw row i of `z` against row i of `x`, shifted up by i * offset.

        `y` is unused and kept for interface compatibility. `scale` only
        selects the y-axis title. Existing line colours are reused when the
        number of traces is unchanged. Returns the plot container.
        """
        self.data_x, self.data_z, self.scale = x, z, scale
        if self.container.components:
            self.colors = [plot.color for plot in self.container.components]
            if self.last_flip_order != self.flip_order:
                self.colors.reverse()
            self.container.remove(*self.container.components)

        # Use a custom renderer so plot lines are clickable
        self.chaco_plot = Plot(self.data, renderer_map={
            'line': ClickableLinePlot
        })
        self.chaco_plot.bgcolor = 'white'
        self.value_mapper = None
        self.index_mapper = None

        if len(self.data_x) == len(self.colors):
            colors = self.colors[:]
        else:
            colors = ['black'] * len(self.data_x)

        if len(z) == 0:
            # No traces: leave the container empty instead of failing below.
            self.last_flip_order = self.flip_order
            return self.container

        if self.flip_order:
            # Reverse x together with z so each trace keeps its own x values.
            x = x[::-1]
            z = z[::-1]

        spacing = (z.max(axis=1) - z.min(axis=1)).min() * self.value_range
        offset = spacing * self.offset
        for i, (x_row, z_row) in enumerate(zip(x, z)):
            x_name = 'data_x_' + str(i)
            y_name = 'data_y_offset_' + str(i)
            self.data.set_data(x_name, x_row)
            self.data.set_data(y_name, z_row * self.value_range + offset * i)
            plot = self.chaco_plot.plot((x_name, y_name), color=colors[i],
                                        type='line')[0]
            self.container.add(plot)
            # Required for double-clicking plots
            plot.index.sort_order = 'ascending'
            plot.value.sort_order = 'ascending'

            # All traces share the first trace's mappers.
            if self.value_mapper is None:
                self.index_mapper = plot.index_mapper
                self.value_mapper = plot.value_mapper
            else:
                plot.value_mapper = self.value_mapper
                self.value_mapper.range.add(plot.value)
                plot.index_mapper = self.index_mapper
                self.index_mapper.range.add(plot.index)

        value_range = self.value_mapper.range
        value_range.high = ((value_range.high - value_range.low)
                            * self.value_range + value_range.low)

        self.x_axis.mapper = self.index_mapper
        self.y_axis.mapper = self.value_mapper
        self.y_axis.title = 'Normalized intensity (%s)' % \
            get_value_scale_label(scale)

        self.zoom_tool = ClickUndoZoomTool(
            plot, tool_mode="box", always_on=True, pointer="cross",
            drag_button=settings.zoom_button,
            undo_button=settings.undo_button,
        )
        plot.overlays.append(self.zoom_tool)

        self.last_flip_order = self.flip_order
        return self.container

    def _reset_view(self):
        # The zoom tool only exists once something has been plotted.
        zoom_tool = getattr(self, 'zoom_tool', None)
        if zoom_tool is not None:
            zoom_tool.revert_history_all()