コード例 #1
0
class HistogramPlotHandler(HasTraits):
    '''
    Class for handling the histograms.

    Draws a bar-chart histogram (``nbins`` bars) of the single column that is
    currently selected through ``selection_handler``. The column is read from
    ``self.data``, which is not declared as a trait here and must be assigned
    by the caller before ``draw_histogram`` runs: it is indexed as a pandas
    DataFrame when ``AS_PANDAS_DATAFRAME`` is true and as a 2-D numpy array
    otherwise.
    '''

    # Index for the histogram plot
    index = Array

    # The selection handler object for the selected data
    selection_handler = Instance(SelectionHandler)

    # OverlayPlotContainer for the histogram plot
    container = Instance(OverlayPlotContainer)

    # Number of bins of the histogram
    nbins = Int(10)

    # Whether the data is a pandas dataframe or a numpy array
    AS_PANDAS_DATAFRAME = Bool

    def __init__(self):
        # Let HasTraits run its own initialisation before traits are assigned.
        super(HistogramPlotHandler, self).__init__()
        # An ndarray (not a py3 ``range`` object) for the Array trait.
        self.index = np.arange(self.nbins)
        self.selection_handler = SelectionHandler()
        self.container = OverlayPlotContainer()

    def draw_histogram(self):
        '''
        Default function called when drawing the histogram.

        Removes every plot from ``container``. If exactly one cell is
        selected, it then adds a histogram of that cell's column. The
        container is always asked to redraw, so a cleared plot disappears.
        The selection is flushed afterwards even if plotting raises.
        '''
        # Iterate over a copy: ``components`` may be the container's live
        # list, and removing from it while looping would skip entries.
        for component in list(self.container.components):
            self.container.remove(component)

        self.selection_handler.create_selection()
        try:
            selected = self.selection_handler.selected_indices
            if len(selected) == 1:
                # Each selected index is a (row, column) pair; only the
                # column matters for a histogram.
                column = selected[0][1]
                if self.AS_PANDAS_DATAFRAME:
                    y = self.data[self.data.columns[column]]
                else:
                    y = self.data[:, column]
                self.index = np.arange(self.nbins)
                hist = np.histogram(y, self.nbins)[0]
                plotdata = ArrayPlotData(x=self.index, y=hist)
                plot = Plot(plotdata)
                plot.plot(("x", "y"), type='bar', bar_width=0.5)
                self.container.add(plot)

            self.container.request_redraw()
        finally:
            self.selection_handler.flush()
コード例 #2
0
ファイル: plot_handlers.py プロジェクト: jilott/enaml-csv
class HistogramPlotHandler(HasTraits):
    """
    Class for handling the histograms.

    Draws a bar-chart histogram (``nbins`` bars) of the single column that is
    currently selected through ``selection_handler``. The column is read from
    ``self.data``, which is not declared as a trait here and must be assigned
    by the caller before ``draw_histogram`` runs: it is indexed as a pandas
    DataFrame when ``AS_PANDAS_DATAFRAME`` is true and as a 2-D numpy array
    otherwise.
    """

    # Index for the histogram plot
    index = Array

    # The selection handler object for the selected data
    selection_handler = Instance(SelectionHandler)

    # OverlayPlotContainer for the histogram plot
    container = Instance(OverlayPlotContainer)

    # Number of bins of the histogram
    nbins = Int(10)

    # Whether the data is a pandas dataframe or a numpy array
    AS_PANDAS_DATAFRAME = Bool

    def __init__(self):
        # Let HasTraits run its own initialisation before traits are assigned.
        super(HistogramPlotHandler, self).__init__()
        # An ndarray (not a py3 ``range`` object) for the Array trait.
        self.index = np.arange(self.nbins)
        self.selection_handler = SelectionHandler()
        self.container = OverlayPlotContainer()

    def draw_histogram(self):
        """
        Default function called when drawing the histogram.

        Removes every plot from ``container``. If exactly one cell is
        selected, it then adds a histogram of that cell's column. The
        container is always asked to redraw, so a cleared plot disappears.
        The selection is flushed afterwards even if plotting raises.
        """
        # Iterate over a copy: ``components`` may be the container's live
        # list, and removing from it while looping would skip entries.
        for component in list(self.container.components):
            self.container.remove(component)

        self.selection_handler.create_selection()
        try:
            selected = self.selection_handler.selected_indices
            if len(selected) == 1:
                # Each selected index is a (row, column) pair; only the
                # column matters for a histogram.
                column = selected[0][1]
                if self.AS_PANDAS_DATAFRAME:
                    y = self.data[self.data.columns[column]]
                else:
                    y = self.data[:, column]
                self.index = np.arange(self.nbins)
                hist = np.histogram(y, self.nbins)[0]
                plotdata = ArrayPlotData(x=self.index, y=hist)
                plot = Plot(plotdata)
                plot.plot(("x", "y"), type="bar", bar_width=0.5)
                self.container.add(plot)

            self.container.request_redraw()
        finally:
            self.selection_handler.flush()
コード例 #3
0
ファイル: ucc.py プロジェクト: magnunor/analyzarr
class CellCropper(StackViewer):
    """
    Template picker that finds and crops repeated "cells" in an image stack.

    Workflow, as implemented below:

    1. A square template of side ``tmp_size`` is cut from the current image
       at (``top``, ``left``). The cursor tool on the image and the
       ``left``/``top`` controls are kept in sync.
    2. The ``findpeaks`` button cross-correlates the template with every
       image (``cv_funcs.xcorr``) and stores the correlation peaks of each
       image in ``peaks``.
    3. The colorbar range selection (or ``thresh_lower``/``thresh_upper``)
       restricts which peaks are accepted. ``numpeaks_img`` and
       ``numpeaks_total`` show how many are accepted.
    4. ``crop_cells`` cuts a ``tmp_size`` x ``tmp_size`` patch at every
       accepted peak and hands the patches to ``controller.add_cells``.

    ``data``, ``img_idx``, ``numfiles``, ``controller``, ``img_plotdata`` and
    the base ``render_image``/``update_img_depth`` presumably come from
    ``StackViewer`` -- TODO confirm.
    """
    # Square patch of the current image used for the cross correlation.
    template = Array
    # Cross-correlation map of ``template`` against ``data``.
    CC = Array
    # One array per image; columns are [x, y, correlation value] of a peak.
    peaks = List
    # Lower bound for the ``top``/``left`` ranges.
    zero=Int(0)
    tmp_size = Range(low=2, high=512, value=64, cols=4)
    # Upper bounds for ``left``/``top`` (None when the template does not fit).
    max_pos_x=Property(depends_on=['tmp_size'])
    max_pos_y=Property(depends_on=['tmp_size'])
    top = Range(low='zero',high='max_pos_y', value=20, cols=4)
    left = Range(low='zero',high='max_pos_x', value=20, cols=4)
    is_square = Bool
    tmp_plot = Instance(Plot)
    findpeaks = Button
    peak_width = Range(low=2, high=200, value=10)
    # Show the cross-correlation map instead of the raw image.
    ShowCC = Bool
    numpeaks_total = Int(0,cols=5)
    numpeaks_img = Int(0,cols=5)
    OK_custom=OK_custom_handler
    # Current (lower, upper) correlation range used to accept peaks.
    thresh=Trait(None,None,List,Tuple,Array)
    thresh_upper=Float(1.0)
    thresh_lower=Float(-1.0)
    # Image index the template was last cut from.
    tmp_img_idx=Int(0)

    # Cursor tool used to drag the template position on the image.
    csr=Instance(BaseCursorTool)

    traits_view = View(
        Group(
            Group(
                Item("img_container",editor=ComponentEditor(), show_label=False),
                HGroup(
                    Item("ShowCC", editor=BooleanEditor(), label="Show cross correlation image"),
                    Spring(),
                    Item("prev_img",editor=ButtonEditor(label="<"),show_label=False, enabled_when='numfiles > 1'),
                    Item("next_img",editor=ButtonEditor(label=">"),show_label=False, enabled_when='numfiles > 1'),
                    ),
                label="Original image", show_border=True, trait_modified="tab_selected",
                orientation='vertical',),
            VGroup(
                Group(
                    HGroup(
                        Item("left", label="Left coordinate", style="custom"),
                        Spring(),
                        Item("top", label="Top coordinate", style="custom"),
                        ),
                    Item("tmp_size", label="Template size", style="custom"),
                    Item("tmp_plot",editor=ComponentEditor(height=256, width=256), show_label=False, resizable=True),
                    label="Template", show_border=True),
                Group(
                    HGroup(
                        Item("peak_width", label="Peak width", style="custom"),
                        Spring(),
                        Item("findpeaks",editor=ButtonEditor(label="Find Peaks"),show_label=False),
                        ),
                    HGroup(
                        Item("thresh_lower",label="Threshold Lower Value", editor=TextEditor(evaluate=float,
                                                                                             format_str='%1.4f')),
                        Spring(),
                        Item("thresh_upper",label="Threshold Upper Value", editor=TextEditor(evaluate=float,
                                                                                             format_str='%1.4f')),
                    ),
                    HGroup(
                        Item("numpeaks_img",label="Number of Cells selected (this image)",style='readonly'),
                        Spring(),
                        Item("numpeaks_total",label="Total",style='readonly'),
                        ),
                    label="Peak parameters", show_border=True),
                ),
            orientation='horizontal'),
        buttons = [ Action(name='OK', enabled_when = 'numpeaks_total > 0' ),
            CancelButton ],
        title="Template Picker",
        handler=OK_custom, kind='livemodal',
        key_bindings = key_bindings,
        width=940, height=530,resizable=True)

    def __init__(self, controller, *args, **kw):
        """
        Initialise the viewer, then build the initial template and its plot.

        If neither ``cv`` nor ``cv2.cv`` can be imported, a message is printed
        and initialisation stops early, leaving the template plot unset.
        NOTE(review): callers get a half-initialised object in that case.
        """
        super(CellCropper, self).__init__(controller, *args, **kw)
        # OpenCV is only probed for availability here; the correlation itself
        # goes through ``cv_funcs.xcorr``.
        try:
            import cv
        except ImportError:
            try:
                import cv2.cv as cv
            except ImportError:
                print("OpenCV unavailable.  Can't do cross correlation without it.  Aborting.")
                return None
        self.OK_custom=OK_custom_handler()
        self.template = self.data[self.top:self.top+self.tmp_size,self.left:self.left+self.tmp_size]
        tmp_plot_data=ArrayPlotData(imagedata=self.template)
        tmp_plot=Plot(tmp_plot_data,default_origin="top left")
        tmp_plot.img_plot("imagedata", colormap=jet)
        tmp_plot.aspect_ratio=1.0
        self.tmp_plot=tmp_plot
        self.tmp_plotdata=tmp_plot_data
        self.crop_sig=None

    def render_image(self):
        """
        Render the current image with a draggable cursor at (left, top).

        Adds a gray image plot and a white ``CursorTool`` overlay to the plot
        from the base class. The cursor is stored as ``csr`` and the plot as
        ``img_plot``. Returns the plot.
        """
        plot = super(CellCropper,self).render_image()
        img=plot.img_plot("imagedata", colormap=gray)[0]
        csr = CursorTool(img, drag_button='left', color='white',
                         line_width=2.0)
        self.csr=csr
        csr.current_position=self.left, self.top
        img.overlays.append(csr)
        self.img_plot=plot
        return plot

    # TODO: use base class render_scatter_overlay method
    def render_scatplot(self):
        """
        Build a scatter plot of the current image's peaks, colored by value.

        Shares ``range2d`` with ``img_plot`` so the markers line up with the
        image. The plot is stored as ``scatplot`` and its data source as
        ``peakdata``.
        """
        peakdata=ArrayPlotData()
        peakdata.set_data("index",self.peaks[self.img_idx][:,0])
        peakdata.set_data("value",self.peaks[self.img_idx][:,1])
        peakdata.set_data("color",self.peaks[self.img_idx][:,2])
        scatplot=Plot(peakdata,aspect_ratio=self.img_plot.aspect_ratio,default_origin="top left")
        scatplot.plot(("index", "value", "color"),
                      type="cmap_scatter",
                      name="my_plot",
                      color_mapper=jet(DataRange1D(low = 0.0,
                                       high = 1.0)),
                      marker = "circle",
                      fill_alpha = 0.5,
                      marker_size = 6,
                      )
        scatplot.x_grid.visible = False
        scatplot.y_grid.visible = False
        scatplot.range2d=self.img_plot.range2d
        self.scatplot=scatplot
        self.peakdata=peakdata
        return scatplot

    def _image_plot_container(self):
        """
        Build ``img_container``: the image (plus peak overlay) and colorbar.

        The scatter overlay and the colorbar are only added when the current
        image has accepted peaks.
        """
        plot = self.render_image()

        # Create a container to position the plot and the colorbar side-by-side
        self.container=OverlayPlotContainer()
        self.container.add(plot)
        self.img_container = HPlotContainer(use_backbuffer = False)
        self.img_container.add(self.container)
        self.img_container.bgcolor = "white"

        if self.numpeaks_img>0:
            scatplot = self.render_scatplot()
            self.container.add(scatplot)
            colorbar = self.draw_colorbar()
            self.img_container.add(colorbar)
        return self.img_container

    def draw_colorbar(self):
        """
        Create a colorbar over [-1, 1] whose range selection drives the
        peak threshold.

        Stores the colorbar, its ``RangeSelection`` tool and the scatter
        renderer as ``colorbar``, ``cbar_selection`` and ``cmap_renderer``.
        Returns the colorbar.
        """
        scatplot=self.scatplot
        cmap_renderer = scatplot.plots["my_plot"][0]
        selection = ColormappedSelectionOverlay(cmap_renderer, fade_alpha=0.35,
                                                selection_type="range")
        cmap_renderer.overlays.append(selection)
        if self.thresh is not None:
            # Re-apply the previous threshold to the freshly created renderer.
            cmap_renderer.color_data.metadata['selections']=self.thresh
            cmap_renderer.color_data.metadata_changed={'selections':self.thresh}
        # Create the colorbar, handing in the appropriate range
        colorbar = ColorBar(
            index_mapper=LinearMapper(
                range=DataRange1D(
                    low = -1.0,
                    high = 1.0)
                ),
            orientation='v',
            resizable='v',
            width=30,
            padding=20)
        colorbar_selection=RangeSelection(component=colorbar)
        colorbar.tools.append(colorbar_selection)
        colorbar.overlays.append(
            RangeSelectionOverlay(component=colorbar,
                                  border_color="white",
                                  alpha=0.8,
                                  fill_color="lightgray",
                                  metadata_name='selections'))
        self.cbar_selection=colorbar_selection
        self.cmap_renderer=cmap_renderer
        colorbar.plot = cmap_renderer
        colorbar.padding_top = scatplot.padding_top
        colorbar.padding_bottom = scatplot.padding_bottom
        self.colorbar=colorbar
        return colorbar

    @on_trait_change('ShowCC')
    def toggle_cc_view(self):
        """Switch the main image between the cross-correlation map and data."""
        if self.ShowCC:
            self.update_CC()
            grid_data_source = self.img_plot.range2d.sources[0]
            grid_data_source.set_data(np.arange(self.CC.shape[1]),
                                      np.arange(self.CC.shape[0]))
        else:
            self.img_plotdata.set_data("imagedata",self.data)
        self.redraw_plots()

    @on_trait_change("img_idx")
    def update_img_depth(self):
        """
        Refresh the view after the current image index changes.

        Recomputes the cross correlation when it is being shown.

        TODO: We look up the index in the model - first get a list of files,
        then get the name of the file at the given index.
        """
        super(CellCropper, self).update_img_depth()
        if self.ShowCC:
            self.update_CC()
        self.redraw_plots()

    def _get_max_pos_x(self):
        """Upper bound for ``left``; None when the template does not fit."""
        max_pos_x=self.data.shape[-1]-self.tmp_size-1
        if max_pos_x>0:
            return max_pos_x
        else:
            return None

    def _get_max_pos_y(self):
        """Upper bound for ``top``; None when the template does not fit."""
        max_pos_y=self.data.shape[-2]-self.tmp_size-1
        if max_pos_y>0:
            return max_pos_y
        else:
            return None

    @on_trait_change('left, top')
    def update_csr_position(self):
        """Move the cursor to (left, top) when the coordinates are edited."""
        # NOTE(review): only propagated while left > 0, so editing ``top``
        # with ``left == 0`` does not move the cursor -- confirm intent.
        if self.left>0:
            self.csr.current_position=self.left,self.top

    @on_trait_change('csr:current_position')
    def update_top_left(self):
        """
        Copy the dragged cursor position into ``left``/``top``.

        Positions past ``max_pos_x``/``max_pos_y`` are clamped so the template
        stays inside the image.
        """
        if self.csr.current_position[0]>0 or self.csr.current_position[1]>0:
            if self.csr.current_position[0]>self.max_pos_x:
                if self.csr.current_position[1]<self.max_pos_y:
                    self.top=self.csr.current_position[1]
                else:
                    self.csr.current_position=self.max_pos_x, self.max_pos_y
            elif self.csr.current_position[1]>self.max_pos_y:
                self.left,self.top=self.csr.current_position[0],self.max_pos_y
            else:
                self.left,self.top=self.csr.current_position

    @on_trait_change('left, top, tmp_size')
    def update_tmp_plot(self):
        """
        Re-cut the template after its position or size changes.

        Any previously found peaks are discarded (replaced by one dummy
        peak), because they belong to the old template.
        """
        self.template = self.data[self.top:self.top+self.tmp_size,self.left:self.left+self.tmp_size]
        self.tmp_plotdata.set_data("imagedata", self.template)
        grid_data_source = self.tmp_plot.range2d.sources[0]
        grid_data_source.set_data(np.arange(self.tmp_size), np.arange(self.tmp_size))
        self.tmp_img_idx=self.img_idx
        if self.numpeaks_total>0:
            print("clearing peaks")
            self.peaks=[np.array([[0,0,-1]])]
        self.update_CC()

    def update_CC(self):
        """Recompute and display the cross correlation when ``ShowCC`` is on."""
        if self.ShowCC:
            self.CC = cv_funcs.xcorr(self.template, self.data)
            self.img_plotdata.set_data("imagedata",self.CC)

    @on_trait_change('cbar_selection:selection')
    def update_thresh(self):
        """
        Push the colorbar range selection to the scatter plot and to the
        threshold text boxes.

        This is best effort. If the selection is cleared or the renderer
        does not exist yet, the update is skipped.
        """
        try:
            thresh=self.cbar_selection.selection
            self.thresh=thresh
            self.cmap_renderer.color_data.metadata['selections']=thresh
            self.thresh_lower=thresh[0]
            self.thresh_upper=thresh[1]
            self.cmap_renderer.color_data.metadata_changed={'selections':thresh}
            self.container.request_redraw()
            self.img_container.request_redraw()
        except Exception:
            pass

    @on_trait_change('thresh_upper,thresh_lower')
    def manual_thresh_update(self):
        """Apply thresholds typed into the text boxes to the scatter plot."""
        self.thresh=[self.thresh_lower,self.thresh_upper]
        cmap_renderer = getattr(self, 'cmap_renderer', None)
        if cmap_renderer is None:
            # No scatter plot/colorbar has been drawn yet; nothing to update.
            return
        cmap_renderer.color_data.metadata['selections']=self.thresh
        cmap_renderer.color_data.metadata_changed={'selections':self.thresh}
        self.container.request_redraw()
        self.img_container.request_redraw()

    @on_trait_change('peaks,cbar_selection:selection,img_idx')
    def calc_numpeaks(self):
        """
        Recount accepted peaks for the whole stack and for the current image.

        A peak is accepted when its correlation value lies within the
        colorbar selection, or within (-1, 1) when there is no usable
        selection.
        """
        try:
            thresh=self.cbar_selection.selection
            self.thresh=thresh
        except Exception:
            # No colorbar yet (or an unusable selection): count everything.
            thresh=[]
        if thresh is None or len(thresh)==0:
            thresh=(-1,1)
        low, high = thresh[0], thresh[1]
        self.numpeaks_total=int(sum(
            np.sum(np.ma.masked_inside(pk[:,2],low,high).mask)
            for pk in self.peaks))
        try:
            self.numpeaks_img=int(np.sum(np.ma.masked_inside(self.peaks[self.img_idx][:,2],low,high).mask))
        except Exception:
            self.numpeaks_img=0

    @on_trait_change('findpeaks')
    def locate_peaks(self):
        """
        Cross-correlate the template with every image and collect peaks.

        Side effect: ``data`` and ``CC`` are left at the last image of the
        stack.
        """
        peaks=[]
        progress = ProgressDialog(title="Peak finder progress", message="Finding peaks on %s images"%self.numfiles, max=self.numfiles, show_time=True, can_cancel=False)
        progress.open()
        for idx in range(self.numfiles):
            self.controller.set_active_index(idx)
            self.data = self.controller.get_active_image()[:]
            self.CC = cv_funcs.xcorr(self.template, self.data)
            # peak finder needs peaks greater than 1.  Multiply by 255 to scale them.
            pks=pc.two_dim_findpeaks(self.CC*255, peak_width=self.peak_width, medfilt_radius=None)
            pks[:,2]=pks[:,2]/255.
            peaks.append(pks)
            progress.update(idx+1)
        self.peaks=peaks
        self.redraw_plots()

    def mask_peaks(self,idx):
        """
        Return the peaks of image ``idx`` with out-of-threshold values masked.

        The colorbar selection is used when present and non-empty; otherwise
        (-1, 1) is used, matching ``calc_numpeaks``.
        """
        thresh = getattr(getattr(self, 'cbar_selection', None), 'selection', None)
        if thresh is None or len(thresh)==0:
            thresh=(-1,1)
        mpeaks=np.ma.asarray(self.peaks[idx])
        mpeaks[:,2]=np.ma.masked_outside(mpeaks[:,2],thresh[0],thresh[1])
        return mpeaks

    @on_trait_change("peaks")
    def redraw_plots(self):
        """Rebuild the image, peak scatter plot and colorbar, then repaint."""
        oldplot=self.img_plot
        self.container.remove(oldplot)
        newplot=self.render_image()
        self.container.add(newplot)
        self.img_plot=newplot

        try:
            # if these haven't been created before, this will fail.  wrap in try to prevent that.
            oldscat=self.scatplot
            self.container.remove(oldscat)
            oldcolorbar = self.colorbar
            self.img_container.remove(oldcolorbar)
        except Exception:
            pass

        if self.numpeaks_img>0:
            newscat=self.render_scatplot()
            self.container.add(newscat)
            self.scatplot=newscat
            colorbar = self.draw_colorbar()
            self.img_container.add(colorbar)
            self.colorbar=colorbar

        self.container.request_redraw()
        self.img_container.request_redraw()

    def crop_cells(self):
        """
        Crop a ``tmp_size`` square at every accepted peak of every image.

        For each image that has accepted peaks, the patches are sent to
        ``controller.add_cells`` together with the peak locations. Side
        effect: ``data`` and ``name`` are left at the last image.
        """
        print("cropping cells...")
        for idx in range(self.numfiles):
            # filter the peaks that are outside the selected threshold
            self.controller.set_active_index(idx)
            self.data = self.controller.get_active_image()
            self.name = self.controller.get_active_name()
            peaks=np.ma.compress_rows(self.mask_peaks(idx))
            tmp_sz=self.tmp_size
            data=np.zeros((peaks.shape[0],tmp_sz,tmp_sz))
            if data.shape[0] >0:
                # Peak coordinates are stored as floats (column 2 holds the
                # correlation value); numpy needs integer slice bounds.
                coords = peaks[:, :2].astype(int)
                for i, (col, row) in enumerate(coords):
                    # crop the cells from the given locations
                    data[i,:,:]=self.data[row:row+tmp_sz, col:col+tmp_sz]
                # send the data to the controller for storage in the chest
                self.controller.add_cells(name = self.name, data = data,
                                      locations = peaks)
コード例 #4
0
ファイル: plot_handlers.py プロジェクト: jilott/enaml-csv
class KMeansPlotHandler(HasTraits):
    """
    Class for plotting the k-means clusters.

    ``create_dataset`` builds ``dataset`` from ``data`` without the columns
    listed in ``to_omit``. ``plot_clusters`` fits ``kmeans`` on that dataset
    and draws each cluster, plus the cluster centres, in a 2-D PCA
    projection.
    """

    # the data to cluster
    data = Array

    # the dataset created after preprocessing
    dataset = Array

    # the sklearn.cluster.KMeans object
    kmeans = Instance(KMeans)

    # Number of clusters
    n_clusters = Int

    # Maximum iterations for the clustering algorithm (0 = estimator default)
    max_iter = Int

    # Container for the cluster plots
    container = Instance(OverlayPlotContainer)

    # the columns from the dataset to omit when performing clustering
    to_omit = List

    def __init__(self):
        # Let HasTraits run its own initialisation before traits are assigned.
        super(KMeansPlotHandler, self).__init__()
        self.kmeans = KMeans()
        self.container = OverlayPlotContainer()

    def create_dataset(self):
        """
        Creates a numpy array from the current selection to pass to the
        sklearn.cluster.kmeans object.

        Every column of ``data`` is kept, in order, except the ones listed in
        ``to_omit``. Entries of ``to_omit`` may be ints or digit strings
        (e.g. ``"2"``); other strings are ignored. When nothing is omitted,
        all columns are used.
        """
        omitted = set()
        for elem in self.to_omit:
            if isinstance(elem, int):
                omitted.add(elem)
            elif elem.isdigit():
                omitted.add(int(elem))

        n_cols = self.data.shape[1]
        keep = [col for col in range(n_cols) if col not in omitted]
        # Fancy indexing keeps the result 2-D even for a single column.
        self.dataset = self.data[:, keep]

    def plot_clusters(self):
        """
        Plots the clusters after calling the .fit method of the sklearn kmeans
        estimator.

        Replaces everything in ``container`` with one scatter plot per
        cluster and one cross-marker plot of the cluster centres.
        """

        self.kmeans.n_clusters = self.n_clusters
        # The Int trait defaults to 0, which sklearn rejects (max_iter must
        # be >= 1); only forward a value the user actually set.
        if self.max_iter > 0:
            self.kmeans.max_iter = self.max_iter
        self.kmeans.fit(self.dataset)

        # Reduce dimensions for plotting. Fit the projection once on the data
        # and reuse it for the centres so both share the same 2-D axes.
        pca = PCA(n_components=2, whiten=True)
        dataset_red = pca.fit_transform(self.dataset)
        cluster_centers = pca.transform(self.kmeans.cluster_centers_)

        # Iterate over a copy: removing from the live list would skip entries.
        for component in list(self.container.components):
            self.container.remove(component)

        labels = self.kmeans.labels_
        for i in range(self.n_clusters):
            current_data = dataset_red[labels == i]

            plotdata = ArrayPlotData(x=current_data[:, 0], y=current_data[:, 1])
            plot = Plot(plotdata)
            # Cycle through the palette when there are more clusters than colors.
            color = tuple(COLOR_PALETTE[i % len(COLOR_PALETTE)])
            plot.plot(("x", "y"), type="scatter", color=color)
            self.container.add(plot)

        plotdata_cent = ArrayPlotData(x=cluster_centers[:, 0], y=cluster_centers[:, 1])
        plot_cent = Plot(plotdata_cent)
        plot_cent.plot(("x", "y"), type="scatter", marker="cross", marker_size=8)
        self.container.add(plot_cent)

        self.container.request_redraw()
コード例 #5
0
ファイル: plot_handlers.py プロジェクト: jilott/enaml-csv
class XYPlotHandler(HasTraits):
    """
    Class for handling XY plots

    Builds line/scatter plots from a two-column selection and keeps them both
    in ``plot_list_view`` (name -> Plot) and in ``container``. Grids and the
    pan/zoom/drag-zoom tools can be toggled on every plot in the container.
    """

    # Whether the data is a pandas dataframe
    AS_PANDAS_DATAFRAME = Bool

    # The container for all current plots. Gets updated everytime a plot is
    # added.
    container = OverlayPlotContainer

    # This can be removed.
    plotdata = ArrayPlotData

    # The current Plot object.
    plot = Plot

    # ColorTrait, mainly required for the TraitsUIItem view.
    color = ColorTrait("blue")

    # Marker trait for the view
    marker = marker_trait

    # Marker size trait
    marker_size = Int(4)

    # An instance of SelectionHandler for adding plots from the current
    # selection.
    selection_handler = Instance(SelectionHandler)

    # Bool traits for checking the type of the plot (discrete / continuous)
    plot_type_disc = Bool
    plot_type_cont = Bool

    # The data from which to draw the plots, same as the table attribute of
    # CsvModel
    table = Array

    # The pandas data frame if AS_PANDAS_DATAFRAME
    data_frame = Instance(DataFrame)

    # Contains the grid underlays of all the current plots
    grid_underlays = List

    # Used for viewing the list of the plots and the legend
    plot_list_view = Dict

    # TraitsUI view for plot properties, yet to find an enaml implementation
    view = View(Item("color"), Item("marker"), Item("marker_size"))

    # Trait that defines whether tools are present.
    add_pan_tool = Bool
    add_zoom_tool = Bool
    add_dragzoom = Bool

    # Whether grids and axes are visible
    show_grid = Bool

    def __init__(self):
        # Let HasTraits run its own initialisation before traits are assigned.
        super(XYPlotHandler, self).__init__()
        self.selection_handler = SelectionHandler()
        self.container = OverlayPlotContainer()
        self.underlays = []
        self.add_pan_tool = False
        self.add_zoom_tool = False
        self.add_dragzoom = False
        self.show_grid = False

    def add_xyplot_selection(self, plot_name):
        """
        Called when the 'add plot from selection' button is clicked.

        Plots the second selected column against the first one. The plot is
        stored in ``plot_list_view`` under ``plot_name``. An empty name picks
        the first unused ``"plotN"`` key. Reusing an existing name replaces
        that plot and also removes it from the container. The selection is
        flushed even if plotting raises.
        """

        self.selection_handler.create_selection()
        try:
            if self.selection_handler.xyplot_check():
                first_column = self.selection_handler.selected_indices[0][1]
                second_column = self.selection_handler.selected_indices[1][1]

                if self.AS_PANDAS_DATAFRAME:
                    x_column = self.data_frame.columns[first_column]
                    y_column = self.data_frame.columns[second_column]
                    x = np.array(self.data_frame[x_column])
                    y = np.array(self.data_frame[y_column])
                    self.plotdata = ArrayPlotData(x=x, y=y)
                else:
                    self.plotdata = ArrayPlotData(x=self.table[:, first_column], y=self.table[:, second_column])

                plot = Plot(self.plotdata)
                plot_type = "scatter" if self.plot_type_disc else "line"
                plot.plot(("x", "y"), type=plot_type, color=self.color, marker=self.marker, marker_size=self.marker_size)
                self.plot = plot

                # Remember the new plot's grid underlays so the grid toggles
                # can re-add them later, then start the plot without grids.
                for underlay in self.plot.underlays:
                    if isinstance(underlay, PlotGrid) and underlay not in self.grid_underlays:
                        self.grid_underlays.append(underlay)
                for underlay in self.grid_underlays:
                    if underlay in self.plot.underlays:
                        self.plot.underlays.remove(underlay)

                if plot_name == "":
                    plot_name = self._next_default_plot_name()
                else:
                    # A plot replaced under the same name would otherwise
                    # stay in the container without being listed.
                    replaced = self.plot_list_view.get(plot_name)
                    if replaced is not None and replaced in self.container.components:
                        self.container.remove(replaced)
                self.plot_list_view[plot_name] = self.plot
                self.container.add(self.plot)

                self.container.request_redraw()
        finally:
            self.selection_handler.flush()

    def _next_default_plot_name(self):
        """Return the first unused ``"plotN"`` key, starting at N = len(plot_list_view)."""
        n = len(self.plot_list_view)
        while "plot" + str(n) in self.plot_list_view:
            n += 1
        return "plot" + str(n)

    def _apply_grid_visibility(self, visible):
        """Add (``visible``) or remove the tracked grid underlays on every plot."""
        for plot in self.container.components:
            for underlay in self.grid_underlays:
                present = underlay in plot.underlays
                if visible and not present:
                    plot.underlays.append(underlay)
                elif not visible and present:
                    plot.underlays.remove(underlay)
        self.container.request_redraw()

    def grid_toggle(self, checked):
        """
        Called when the 'Show Grid' checkbox is toggled
        """
        self._apply_grid_visibility(checked)

    def remove_selected_plots(self, selection):
        """
        Called when the 'Remove Selected Plots' button is clicked

        ``selection`` holds model indices whose ``[0].row`` is the position of
        a plot in ``plot_list_view``'s key order. The container is then rebuilt
        from the plots that remain.
        """
        names = list(self.plot_list_view.keys())
        remove_names = []
        for model_index in selection:
            name = names[model_index[0].row]
            # The same row can appear more than once in a selection.
            if name not in remove_names:
                remove_names.append(name)

        for name in remove_names:
            del self.plot_list_view[name]
        # Iterate over a copy: removing from the live list would skip entries.
        for plot in list(self.container.components):
            self.container.remove(plot)
        for name in self.plot_list_view:
            self.container.add(self.plot_list_view[name])
        self.container.request_redraw()

    def edit_selection(self, show_grid, plot_visible, plot_type_disc):
        """
        Called to start editing the selected plot. Should accompany the 'Edit
        Plot' dialog.

        Replaces the current plot with one that uses the current color,
        marker, plot type and visibility, and strips its grids unless
        ``show_grid``. Entries of ``plot_list_view`` that pointed at the old
        plot are updated to point at the new one.
        """
        old_plot = self.plot
        self.container.remove(old_plot)

        self.plot_type_disc = plot_type_disc
        plot_type = "scatter" if self.plot_type_disc else "line"

        plot = Plot(old_plot.data)
        plot.plot(("x", "y"), color=self.color, type=plot_type, marker=self.marker, marker_size=self.marker_size)
        self.plot = plot
        self.plot.visible = plot_visible

        if not show_grid:
            grid_underlays = [underlay for underlay in self.plot.underlays
                              if isinstance(underlay, PlotGrid)]
            for underlay in grid_underlays:
                self.plot.underlays.remove(underlay)

        # Keep the plot list in sync; otherwise remove_selected_plots would
        # rebuild the container with the old plot.
        for name, listed in list(self.plot_list_view.items()):
            if listed is old_plot:
                self.plot_list_view[name] = plot

        self.container.add(self.plot)
        self.container.request_redraw()

        self.selection_handler.flush()

    def _toggle_broadcast_tool(self, enabled, slot, make_tool):
        """
        Install or remove the BroadcasterTool remembered in attribute ``slot``.

        When ``enabled`` and the container has plots, a single broadcaster
        holding ``make_tool(plot)`` for every plot is appended to
        ``container.tools``. The broadcaster previously installed under
        ``slot`` is always removed first. Tracking each broadcaster separately
        means that disabling one tool (e.g. pan) leaves the others in place.
        """
        previous = getattr(self, slot, None)
        if previous is not None and previous in self.container.tools:
            self.container.tools.remove(previous)
        setattr(self, slot, None)
        if not enabled or not self.container.components:
            return
        broadcaster = BroadcasterTool()
        for plot in self.container.components:
            broadcaster.tools.append(make_tool(plot))
        self.container.tools.append(broadcaster)
        setattr(self, slot, broadcaster)

    def _add_pan_tool_changed(self):
        """
        Method called when the Pan Tool checkbox is checked or unchecked.
        Adds the Pan Tool to the plot container if it isn't there and vice versa.
        """
        self._toggle_broadcast_tool(self.add_pan_tool, "_pan_broadcaster", PanTool)

    def _add_zoom_tool_changed(self):
        """
        Method called when the Zoom Tool checkbox is checked or unchecked.
        Adds the Zoom Tool to the plot container if it isn't there and vice versa.
        """
        self._toggle_broadcast_tool(self.add_zoom_tool, "_zoom_broadcaster", ZoomTool)

    def _add_dragzoom_changed(self):
        """
        Method called when the drag-zoom checkbox is checked or unchecked.
        Adds a box-select zoom tool (left drag) to the plot container if it
        isn't there and vice versa.
        """

        def make_dragzoom(plot):
            return BetterSelectingZoom(
                plot,
                always_on=True,
                tool_mode="box",
                drag_button="left",
                color="lightskyblue",
                alpha=0.4,
                border_color="dodgerblue",
            )

        self._toggle_broadcast_tool(self.add_dragzoom, "_dragzoom_broadcaster", make_dragzoom)

    def _show_grid_changed(self):
        """
        Called when the Show grid checkbox is checked or unchecked. Adds a grid
        if one is not present and removes if present.
        """
        self._apply_grid_visibility(self.show_grid)

    def reassign_current_plot(self):
        """
        Reassigns the currently selected plot.

        The first selected row is mapped to a plot name by its position in
        ``plot_list_view``'s key order. The selection is flushed even on
        failure.
        """

        self.selection_handler.create_selection()
        try:
            plot_index = self.selection_handler.selected_indices[0][0]
            plot_name = list(self.plot_list_view.keys())[plot_index]
            self.plot = self.plot_list_view[plot_name]
        finally:
            self.selection_handler.flush()
コード例 #6
0
ファイル: chaco_plot.py プロジェクト: conkiztador/pdviper
class StackedPlot(ChacoPlot):
    """
    Chaco plot that draws each dataset as its own line, offset vertically.

    Traits
    ------
    offset : vertical gap between consecutive traces, as a fraction of the
        smallest per-trace value span (after scaling by ``value_range``).
    value_range : factor applied to every trace's values and to the top of
        the shared value range.
    flip_order : draw the traces in reverse order.
    """

    offset = Range(0.0, 1.0, 0.015)
    value_range = Range(0.01, 1.05, 1.00)
    flip_order = Bool(False)

    def _get_traits_group(self):
        """Return the TraitsUI group: the three layout controls above the plot."""
        return VGroup(
            HGroup(Item("flip_order"), Item("offset"), Item("value_range")),
            UItem("component", editor=ComponentEditor()),
        )

    def __init__(self):
        super(StackedPlot, self).__init__()
        self.container = OverlayPlotContainer(
            bgcolor="white", use_backbuffer=True, border_visible=True, padding=50, padding_left=110, fill_padding=True
        )
        self.data = ArrayPlotData()
        self.chaco_plot = None
        self.value_mapper = None
        self.index_mapper = None
        self.x_axis = PlotAxis(
            component=self.container,
            orientation="bottom",
            title=u"Angle (2\u0398)",
            title_font=settings.axis_title_font,
            tick_label_font=settings.tick_font,
        )
        y_axis_title = "Normalized intensity (%s)" % get_value_scale_label("linear")
        self.y_axis = PlotAxis(
            component=self.container,
            orientation="left",
            title=y_axis_title,
            title_font=settings.axis_title_font,
            tick_label_font=settings.tick_font,
        )
        self.container.overlays.extend([self.x_axis, self.y_axis])
        self.container.tools.append(TraitsTool(self.container, classes=[LinePlot, PlotAxis]))
        self.colors = []
        self.last_flip_order = self.flip_order

    @on_trait_change("offset, value_range, flip_order")
    def _replot_data(self):
        """Redraw the last plotted data when one of the layout traits changes."""
        # The traits can change before anything has been plotted; _plot is
        # what sets data_x / data_z / scale, so there is nothing to redraw yet.
        if getattr(self, "data_x", None) is None:
            return
        self._plot(self.data_x, None, self.data_z, self.scale)
        self.container.request_redraw()

    def _prepare_data(self, datasets):
        """
        Turn ``datasets`` into the (x, y, z) triple consumed by ``_plot``.

        The datasets are stacked and passed to ``interpolate_datasets`` with
        4800 points, which returns one x grid and the z rows (presumably all
        resampled onto that common grid). x is repeated once per dataset so
        x and z have matching rows; y is unused and returned as None.
        """
        stack = stack_datasets(datasets)
        x, z = interpolate_datasets(stack, points=4800)
        x = array([x] * len(datasets))
        return x, None, z

    def _plot(self, x, y, z, scale):
        """
        Draw one line per row of ``z`` against the matching row of ``x``.

        Parameters
        ----------
        x, z : sequences of rows, one row per trace (z must support
            ``max(axis=1)`` / ``min(axis=1)``).
        y : ignored.
        scale : value-scale name passed to ``get_value_scale_label`` for the
            y-axis title.

        Returns
        -------
        The plot container.
        """
        self.data_x, self.data_z, self.scale = x, z, scale
        if self.container.components:
            # Keep the current line colours across the redraw; reverse them
            # when the stacking order has just been flipped. A list (not
            # map(), which is an iterator on Python 3) so .reverse() works.
            self.colors = [plot.color for plot in self.container.components]
            if self.last_flip_order != self.flip_order:
                self.colors.reverse()
            self.container.remove(*self.container.components)
        # Use a custom renderer so plot lines are clickable
        self.chaco_plot = Plot(self.data, renderer_map={"line": ClickableLinePlot})
        self.chaco_plot.bgcolor = "white"
        self.value_mapper = None
        self.index_mapper = None

        if len(x) == 0:
            # Nothing to draw: leave the container empty rather than failing
            # on an empty min() or an undefined last plot below.
            self.last_flip_order = self.flip_order
            return self.container

        if len(self.data_x) == len(self.colors):
            colors = self.colors[:]
        else:
            colors = ["black"] * len(self.data_x)

        if self.flip_order:
            z = z[::-1]

        # Offsets are relative to the smallest value span among the traces.
        spacing = (z.max(axis=1) - z.min(axis=1)).min() * self.value_range
        offset = spacing * self.offset
        for i, (x_row, z_row) in enumerate(zip(x, z)):
            self.data.set_data("data_x_" + str(i), x_row)
            self.data.set_data("data_y_offset_" + str(i), z_row * self.value_range + offset * i)
            plots = self.chaco_plot.plot(("data_x_" + str(i), "data_y_offset_" + str(i)), color=colors[i], type="line")
            plot = plots[0]
            self.container.add(plot)
            # Required for double-clicking plots
            plot.index.sort_order = "ascending"
            plot.value.sort_order = "ascending"

            # All traces share the first trace's mappers so they use one
            # common index range and one common value range.
            if self.value_mapper is None:
                self.index_mapper = plot.index_mapper
                self.value_mapper = plot.value_mapper
            else:
                plot.value_mapper = self.value_mapper
                self.value_mapper.range.add(plot.value)
                plot.index_mapper = self.index_mapper
                self.index_mapper.range.add(plot.index)
        # Scale the top of the shared value range by value_range
        # (named so it does not shadow the builtin range()).
        shared_range = self.value_mapper.range
        shared_range.high = (shared_range.high - shared_range.low) * self.value_range + shared_range.low
        self.x_axis.mapper = self.index_mapper
        self.y_axis.mapper = self.value_mapper
        self.y_axis.title = "Normalized intensity (%s)" % get_value_scale_label(scale)
        self.zoom_tool = ClickUndoZoomTool(
            plot,
            tool_mode="box",
            always_on=True,
            pointer="cross",
            drag_button=settings.zoom_button,
            undo_button=settings.undo_button,
        )
        plot.overlays.append(self.zoom_tool)
        self.last_flip_order = self.flip_order
        return self.container

    def _reset_view(self):
        """Undo all zoom steps; a no-op if nothing has been plotted yet."""
        zoom_tool = getattr(self, "zoom_tool", None)
        if zoom_tool is not None:
            zoom_tool.revert_history_all()
コード例 #7
0
class KMeansPlotHandler(HasTraits):
    '''
    Class for plotting the k-means clusters.

    ``create_dataset`` builds ``dataset`` from the columns of ``data``;
    ``plot_clusters`` fits ``kmeans`` on ``dataset`` and draws one scatter
    plot per cluster plus the cluster centres into ``container``.
    '''

    # the data to cluster
    data = Array

    # the dataset created after preprocessing
    dataset = Array

    # the sklearn.cluster.KMeans object
    kmeans = Instance(KMeans)

    # Number of clusters
    n_clusters = Int

    # Maximum iterations for the clustering algorithm (only forwarded to
    # kmeans when positive)
    max_iter = Int

    # Container for the cluster plots
    container = Instance(OverlayPlotContainer)

    # the columns from the dataset to omit when performing clustering
    to_omit = List

    def __init__(self):
        # Let HasTraits complete its own initialisation first.
        super(KMeansPlotHandler, self).__init__()
        self.kmeans = KMeans()
        self.container = OverlayPlotContainer()

    def create_dataset(self):
        '''
        Build ``self.dataset`` from the columns of the 2D array ``self.data``
        that are not listed in ``self.to_omit``; this is what
        ``plot_clusters`` fits.

        ``to_omit`` entries are column indices given as digit strings (ints
        are accepted too, surrounding whitespace is ignored); other entries
        are skipped. With nothing to omit, every column is kept. Column order
        is preserved.
        '''
        omitted = set()
        for elem in self.to_omit:
            text = str(elem).strip()
            if text.isdigit():
                omitted.add(int(text))

        n_cols = self.data.shape[1]
        kept = [col for col in range(n_cols) if col not in omitted]
        # Column-select in one step. Unlike the old hstack loop, which
        # always started from column 0, this also lets column 0 be omitted.
        self.dataset = self.data[:, kept]

    def plot_clusters(self):
        '''
        Fit ``self.kmeans`` on ``self.dataset`` and redraw ``self.container``
        with one scatter plot per cluster and a scatter plot of the cluster
        centres (cross markers).

        The data are projected to 2D with a whitened PCA for display. The PCA
        is fitted on the dataset only, and the centres are projected with
        that same transform, so points and centres share one coordinate
        system.
        '''
        self.kmeans.n_clusters = self.n_clusters
        # The Int trait defaults to 0, which is not a valid iteration count;
        # only forward a value that has actually been set.
        if self.max_iter > 0:
            self.kmeans.max_iter = self.max_iter
        self.kmeans.fit(self.dataset)

        # Reduce the dataset to two dimensions for plotting, then map the
        # cluster centres through the *same* projection.
        pca = PCA(n_components=2, whiten=True)
        dataset_red = pca.fit_transform(self.dataset)
        cluster_centers = pca.transform(self.kmeans.cluster_centers_)

        # Iterate over a copy: removing from the live list would skip items.
        for component in list(self.container.components):
            self.container.remove(component)

        labels = self.kmeans.labels_
        for i in range(self.n_clusters):
            current_data = dataset_red[labels == i, :]

            plotdata = ArrayPlotData(x=current_data[:, 0],
                                     y=current_data[:, 1])
            plot = Plot(plotdata)
            # Cycle through the palette if there are more clusters than colours.
            plot.plot(("x", "y"),
                      type='scatter',
                      color=tuple(COLOR_PALETTE[i % len(COLOR_PALETTE)]))
            self.container.add(plot)

        plotdata_cent = ArrayPlotData(x=cluster_centers[:, 0],
                                      y=cluster_centers[:, 1])
        plot_cent = Plot(plotdata_cent)
        plot_cent.plot(("x", "y"),
                       type='scatter',
                       marker='cross',
                       marker_size=8)
        self.container.add(plot_cent)

        self.container.request_redraw()
コード例 #8
0
class XYPlotHandler(HasTraits):
    '''
    Class for handling XY plots.

    Each plot is built from two selected columns (x, y) of ``table`` or
    ``data_frame``, stored by name in ``plot_list_view`` and drawn in
    ``container``. Pan / zoom / drag-zoom tools and grid visibility are
    toggled through the Bool traits below.
    '''

    # Whether the data is a pandas dataframe
    AS_PANDAS_DATAFRAME = Bool

    # The container for all current plots. Gets updated every time a plot is
    # added.
    container = OverlayPlotContainer

    # This can be removed.
    plotdata = ArrayPlotData

    # The current Plot object.
    plot = Plot

    # ColorTrait, mainly required for the TraitsUIItem view.
    color = ColorTrait("blue")

    # Marker trait for the view
    marker = marker_trait

    # Marker size trait
    marker_size = Int(4)

    # An instance of SelectionHandler for adding plots from the current
    # selection.
    selection_handler = Instance(SelectionHandler)

    # Bool traits for checking the type of the plot (discrete / continuous)
    plot_type_disc = Bool
    plot_type_cont = Bool

    # The data from which to draw the plots, same as the table attribute of
    # CsvModel
    table = Array

    # The pandas data frame if AS_PANDAS_DATAFRAME
    data_frame = Instance(DataFrame)

    # Contains the grid underlays of all the current plots
    grid_underlays = List

    # Used for viewing the list of the plots and the legend
    plot_list_view = Dict

    # TraitsUI view for plot properties, yet to find an enaml implementation
    view = View(Item('color'), Item('marker'), Item('marker_size'))

    # Trait that defines whether tools are present.
    add_pan_tool = Bool
    add_zoom_tool = Bool
    add_dragzoom = Bool

    # Whether grids and axes are visible
    show_grid = Bool

    # The BroadcasterTool currently installed for each toggleable tool
    # ('pan', 'zoom', 'dragzoom'), so each checkbox only removes its own tool.
    _tool_broadcasters = Dict

    def __init__(self):
        # Let HasTraits complete its own initialisation first.
        super(XYPlotHandler, self).__init__()
        self.selection_handler = SelectionHandler()
        self.container = OverlayPlotContainer()
        self.underlays = []
        self.add_pan_tool = False
        self.add_zoom_tool = False
        self.add_dragzoom = False
        self.show_grid = False

    def add_xyplot_selection(self, plot_name):
        '''
        Called when the 'Add plot from selection' button is clicked.

        If ``selection_handler.xyplot_check()`` accepts the current selection,
        the first selected column is plotted as x against the second as y,
        using ``color``, ``marker`` and ``marker_size``. It is drawn as a
        scatter plot when ``plot_type_disc`` is set and as a line plot
        otherwise. The new plot becomes ``self.plot`` and is stored in
        ``plot_list_view`` under ``plot_name``, or under an unused
        ``'plotN'`` name when ``plot_name`` is empty. It is then added to
        ``container``.

        The plot's grid underlays are registered in ``grid_underlays``; grids
        are shown on it only when ``show_grid`` is set. The selection handler
        is flushed afterwards.
        '''

        self.selection_handler.create_selection()
        if self.selection_handler.xyplot_check():
            x_index = self.selection_handler.selected_indices[0][1]
            y_index = self.selection_handler.selected_indices[1][1]

            if self.AS_PANDAS_DATAFRAME:
                x_column = self.data_frame.columns[x_index]
                y_column = self.data_frame.columns[y_index]
                x = np.array(self.data_frame[x_column])
                y = np.array(self.data_frame[y_column])
                self.plotdata = ArrayPlotData(x=x, y=y)
            else:
                self.plotdata = ArrayPlotData(x=self.table[:, x_index],
                                              y=self.table[:, y_index])

            plot = Plot(self.plotdata)

            if self.plot_type_disc:
                plot_type = 'scatter'
            else:
                plot_type = 'line'
            plot.plot(("x", "y"),
                      type=plot_type,
                      color=self.color,
                      marker=self.marker,
                      marker_size=self.marker_size)

            self.plot = plot

            self._register_grid_underlays(self.plot)
            # Follow the current grid setting rather than always hiding it.
            self._set_grid_visibility(self.show_grid, [self.plot])

            if not plot_name:
                plot_name = self._unused_plot_name()
            self.plot_list_view[plot_name] = self.plot
            self.container.add(self.plot)

            self.container.request_redraw()

        self.selection_handler.flush()

    def _unused_plot_name(self):
        '''
        Return ``'plotN'`` for the smallest N >= len(plot_list_view) whose
        name is not already a key. This avoids overwriting an existing entry
        after plots have been removed.
        '''
        number = len(self.plot_list_view)
        while 'plot' + str(number) in self.plot_list_view:
            number += 1
        return 'plot' + str(number)

    def _register_grid_underlays(self, plot):
        '''
        Record ``plot``'s PlotGrid underlays in ``grid_underlays`` (once each)
        so the grid toggles control them. Return the grids found.
        '''
        grids = [u for u in plot.underlays if isinstance(u, PlotGrid)]
        for grid in grids:
            if grid not in self.grid_underlays:
                self.grid_underlays.append(grid)
        return grids

    def _set_grid_visibility(self, visible, plots):
        '''
        Attach every underlay in ``grid_underlays`` to each of ``plots`` when
        ``visible`` is true, or detach them otherwise. Does not redraw.
        '''
        for plot in plots:
            for underlay in self.grid_underlays:
                present = underlay in plot.underlays
                if visible and not present:
                    plot.underlays.append(underlay)
                elif not visible and present:
                    plot.underlays.remove(underlay)

    def grid_toggle(self, checked):
        '''
        Called when the 'Show Grid' checkbox is toggled: show the known grid
        underlays on every plot if ``checked`` is true, hide them otherwise,
        then redraw.
        '''
        self._set_grid_visibility(checked, self.container.components)
        self.container.request_redraw()

    def remove_selected_plots(self, selection):
        '''
        Called when the 'Remove Selected Plots' button is clicked.

        For each item in ``selection``, ``item[0].row`` (presumably a model
        index's row, TODO confirm) is used as a position into
        ``plot_list_view``'s keys. Those plots are dropped from
        ``plot_list_view`` (a row selected more than once is dropped once),
        and the container is rebuilt from the remaining plots.
        '''
        # Resolve every name before popping, so positions refer to the
        # original key order.
        names = list(self.plot_list_view.keys())
        remove_names = [names[model_index[0].row] for model_index in selection]
        for name in remove_names:
            self.plot_list_view.pop(name, None)
        # Iterate over a copy: removing from the live list skips every
        # other plot.
        for plot in list(self.container.components):
            self.container.remove(plot)
        for plot in self.plot_list_view.values():
            self.container.add(plot)
        self.container.request_redraw()

    def edit_selection(self, show_grid, plot_visible, plot_type_disc):
        '''
        Rebuild ``self.plot`` with the current style traits. Should accompany
        the 'Edit Plot' dialog.

        Parameters
        ----------
        show_grid : keep the new plot's PlotGrid underlays if true, strip
            them otherwise.
        plot_visible : value for the new plot's ``visible`` flag.
        plot_type_disc : scatter plot if true, line plot otherwise (also
            stored in ``self.plot_type_disc``).

        The old plot is replaced both in ``container`` and in every
        ``plot_list_view`` entry that referred to it. The selection handler
        is flushed afterwards.
        '''
        old_plot = self.plot
        self.container.remove(old_plot)

        self.plot_type_disc = plot_type_disc

        if self.plot_type_disc:
            plot_type = 'scatter'
        else:
            plot_type = 'line'

        plot = Plot(old_plot.data)
        plot.plot(("x", "y"),
                  color=self.color,
                  type=plot_type,
                  marker=self.marker,
                  marker_size=self.marker_size)
        self.plot = plot
        self.plot.visible = plot_visible

        # Register the new grids so the grid toggles can control them too.
        new_grids = self._register_grid_underlays(self.plot)
        if not show_grid:
            for underlay in new_grids:
                self.plot.underlays.remove(underlay)

        # Keep the plot list pointing at the live plot, otherwise a later
        # remove_selected_plots would re-add the stale one.
        for name, value in list(self.plot_list_view.items()):
            if value is old_plot:
                self.plot_list_view[name] = self.plot

        self.container.add(self.plot)
        self.container.request_redraw()

        self.selection_handler.flush()

    def _toggle_broadcast_tool(self, key, enabled, make_tool):
        '''
        Install or remove the BroadcasterTool that forwards container events
        to one ``make_tool(plot)`` instance per plot in the container.

        ``key`` identifies the toggle: any broadcaster previously installed
        under it is removed first, so only one exists per key and other keys'
        tools are left alone. Nothing is installed when the container has no
        plots.
        '''
        previous = self._tool_broadcasters.pop(key, None)
        if previous is not None and previous in self.container.tools:
            self.container.tools.remove(previous)
        if enabled and self.container.components:
            broadcaster = BroadcasterTool()
            for plot in self.container.components:
                broadcaster.tools.append(make_tool(plot))
            self.container.tools.append(broadcaster)
            self._tool_broadcasters[key] = broadcaster

    def _add_pan_tool_changed(self):
        '''
        Method called when the Pan Tool checkbox is checked or unchecked.
        Adds a PanTool for every plot to the container, or removes it.
        '''
        self._toggle_broadcast_tool('pan', self.add_pan_tool, PanTool)

    def _add_zoom_tool_changed(self):
        '''
        Method called when the Zoom Tool checkbox is checked or unchecked.
        Adds a ZoomTool for every plot to the container, or removes it.
        '''
        self._toggle_broadcast_tool('zoom', self.add_zoom_tool, ZoomTool)

    def _add_dragzoom_changed(self):
        '''
        Method called when the drag-zoom checkbox is checked or unchecked.
        Adds a left-drag box-zoom tool (BetterSelectingZoom) for every plot
        to the container, or removes it.
        '''
        def make_dragzoom(plot):
            return BetterSelectingZoom(plot,
                                       always_on=True,
                                       tool_mode='box',
                                       drag_button='left',
                                       color='lightskyblue',
                                       alpha=0.4,
                                       border_color='dodgerblue')

        self._toggle_broadcast_tool('dragzoom', self.add_dragzoom,
                                    make_dragzoom)

    def _show_grid_changed(self):
        '''
        Called when the Show grid checkbox is checked or unchecked. Attaches
        the known grid underlays to every plot if ``show_grid`` is set,
        detaches them otherwise, then redraws.
        '''
        self._set_grid_visibility(self.show_grid, self.container.components)
        self.container.request_redraw()

    def reassign_current_plot(self):
        '''
        Make the plot picked in the plot list the current plot.

        The first element of the first selected index tuple is used as a
        position into ``plot_list_view``'s keys. The selection handler is
        always flushed, even if that lookup fails.
        '''
        self.selection_handler.create_selection()
        try:
            plot_index = self.selection_handler.selected_indices[0][0]
            # list() so positional indexing also works on Python 3 views.
            plot_name = list(self.plot_list_view.keys())[plot_index]
            self.plot = self.plot_list_view[plot_name]
        finally:
            self.selection_handler.flush()
コード例 #9
0
ファイル: chaco_plot.py プロジェクト: adam-urbanczyk/pdviper
class StackedPlot(ChacoPlot):
    """
    Chaco plot that draws each dataset as its own line, offset vertically.

    Traits
    ------
    offset : vertical gap between consecutive traces, as a fraction of the
        smallest per-trace value span (after scaling by ``value_range``).
    value_range : factor applied to every trace's values and to the top of
        the shared value range.
    flip_order : draw the traces in reverse order.
    """

    offset = Range(0.0, 1.0, 0.015)
    value_range = Range(0.01, 1.05, 1.00)
    flip_order = Bool(False)

    def _get_traits_group(self):
        """Return the TraitsUI group: the three layout controls above the plot."""
        return VGroup(
                   HGroup(
                       Item('flip_order'),
                       Item('offset'),
                       Item('value_range'),
                   ),
                   UItem('component', editor=ComponentEditor()),
               )

    def __init__(self):
        super(StackedPlot, self).__init__()
        self.container = OverlayPlotContainer(bgcolor='white',
                                              use_backbuffer=True,
                                              border_visible=True,
                                              padding=50,
                                              padding_left=110,
                                              fill_padding=True)
        self.data = ArrayPlotData()
        self.chaco_plot = None
        self.value_mapper = None
        self.index_mapper = None
        self.x_axis = MyPlotAxis(component=self.container,
                                 orientation='bottom',
                                 title=u'Angle (2\u0398)',
                                 title_font=settings.axis_title_font,
                                 tick_label_font=settings.tick_font)
        y_axis_title = 'Normalized intensity (%s)' % get_value_scale_label('linear')
        self.y_axis = MyPlotAxis(component=self.container,
                                 orientation='left',
                                 title=y_axis_title,
                                 title_font=settings.axis_title_font,
                                 tick_label_font=settings.tick_font)
        self.container.overlays.extend([self.x_axis, self.y_axis])
        self.container.tools.append(
            TraitsTool(self.container, classes=[LinePlot, MyPlotAxis]))
        self.colors = []
        self.last_flip_order = self.flip_order

    @on_trait_change('offset, value_range, flip_order')
    def _replot_data(self):
        """Redraw the last plotted data when one of the layout traits changes."""
        # The traits can change before anything has been plotted; _plot is
        # what sets data_x / data_z / scale, so there is nothing to redraw yet.
        if getattr(self, 'data_x', None) is None:
            return
        self._plot(self.data_x, None, self.data_z, self.scale)
        self.container.request_redraw()

    def _prepare_data(self, stack):
        """
        Split ``stack`` into the (x, y, z) triple consumed by ``_plot``.

        ``stack`` is indexed as ``stack[row, point, channel]``: channel 0
        gives x and channel 2 gives z, one row per trace. y is unused and
        returned as None.
        """
        x = stack[:, :, 0]
        z = stack[:, :, 2]
        return x, None, z

    def _plot(self, x, y, z, scale):
        """
        Draw one line per row of ``z`` against the matching row of ``x``.

        Parameters
        ----------
        x, z : sequences of rows, one row per trace (z must support
            ``max(axis=1)`` / ``min(axis=1)``).
        y : ignored.
        scale : value-scale name passed to ``get_value_scale_label`` for the
            y-axis title.

        Returns
        -------
        The plot container.
        """
        self.data_x, self.data_z, self.scale = x, z, scale
        if self.container.components:
            # Keep the current line colours across the redraw; reverse them
            # when the stacking order has just been flipped. A list (not
            # map(), which is an iterator on Python 3) so .reverse() works.
            self.colors = [plot.color for plot in self.container.components]
            if self.last_flip_order != self.flip_order:
                self.colors.reverse()
            self.container.remove(*self.container.components)
        # Use a custom renderer so plot lines are clickable
        self.chaco_plot = Plot(self.data,
                               renderer_map={'line': ClickableLinePlot})
        self.chaco_plot.bgcolor = 'white'
        self.value_mapper = None
        self.index_mapper = None

        if len(x) == 0:
            # Nothing to draw: leave the container empty rather than failing
            # on an empty min() or an undefined last plot below.
            self.last_flip_order = self.flip_order
            return self.container

        if len(self.data_x) == len(self.colors):
            colors = self.colors[:]
        else:
            colors = ['black'] * len(self.data_x)

        if self.flip_order:
            z = z[::-1]

        # Offsets are relative to the smallest value span among the traces.
        spacing = (z.max(axis=1) - z.min(axis=1)).min() * self.value_range
        offset = spacing * self.offset
        for i, (x_row, z_row) in enumerate(zip(x, z)):
            self.data.set_data('data_x_' + str(i), x_row)
            self.data.set_data('data_y_offset_' + str(i), z_row * self.value_range + offset * i)
            plots = self.chaco_plot.plot(('data_x_' + str(i), 'data_y_offset_' + str(i)), color=colors[i], type='line')
            plot = plots[0]
            self.container.add(plot)
            # Required for double-clicking plots
            plot.index.sort_order = 'ascending'
            plot.value.sort_order = 'ascending'

            # All traces share the first trace's mappers so they use one
            # common index range and one common value range.
            if self.value_mapper is None:
                self.index_mapper = plot.index_mapper
                self.value_mapper = plot.value_mapper
            else:
                plot.value_mapper = self.value_mapper
                self.value_mapper.range.add(plot.value)
                plot.index_mapper = self.index_mapper
                self.index_mapper.range.add(plot.index)
        # Scale the top of the shared value range by value_range
        # (named so it does not shadow the builtin range()).
        shared_range = self.value_mapper.range
        shared_range.high = (shared_range.high - shared_range.low) * self.value_range + shared_range.low
        self.x_axis.mapper = self.index_mapper
        self.y_axis.mapper = self.value_mapper
        self.y_axis.title = 'Normalized intensity (%s)' % \
                get_value_scale_label(scale)
        self.zoom_tool = ClickUndoZoomTool(
            plot, tool_mode="box", always_on=True, pointer="cross",
            drag_button=settings.zoom_button,
            undo_button=settings.undo_button,
        )
        plot.overlays.append(self.zoom_tool)
        self.last_flip_order = self.flip_order
        return self.container

    def _reset_view(self):
        """Undo all zoom steps; a no-op if nothing has been plotted yet."""
        zoom_tool = getattr(self, 'zoom_tool', None)
        if zoom_tool is not None:
            zoom_tool.revert_history_all()
コード例 #10
0
ファイル: ucc.py プロジェクト: Gazworth/hyperspy
class TemplatePicker(HasTraits):
    # Interactive TraitsUI dialog ("Template Picker") for choosing a square
    # template region in a stack of images and finding peaks. The view's
    # "Show cross correlation image" option suggests peaks come from
    # cross-correlating the template with each image; the code doing that
    # (update_CC, peak finding) is not visible in this chunk.

    # Template image and cross-correlation result (filled in by code outside
    # this chunk; toggle_cc_view reads CC.shape).
    template = Array
    CC = Array
    # Per-image peak arrays; render_scatplot reads columns 0, 1, 2 of
    # peaks[img_idx] as x, y and colour value.
    peaks = List
    # Lower bound for the `top` / `left` Range traits.
    zero=Int(0)
    # Side length of the square template, in pixels.
    tmp_size = Range(low=2, high=512, value=64, cols=4)
    # Upper bounds for `left` / `top`, recomputed when tmp_size changes
    # (see _get_max_pos_x / _get_max_pos_y).
    max_pos_x=Property(depends_on=['tmp_size'])
    max_pos_y=Property(depends_on=['tmp_size'])
    # Template's top-left corner in the current image: `top` is the row
    # offset, `left` the column offset (see the slicing in __init__).
    top = Range(low='zero',high='max_pos_y', value=20, cols=4)
    left = Range(low='zero',high='max_pos_x', value=20, cols=4)
    is_square = Bool
    # Main image plot and template preview plot.
    img_plot = Instance(Plot)
    tmp_plot = Instance(Plot)
    # Buttons: run peak finding; step forward / back through the images.
    findpeaks = Button
    next_img = Button
    prev_img = Button
    peak_width = Range(low=2, high=200, value=10)
    tab_selected = Event
    # When set, the main image shows the cross correlation image
    # (see toggle_cc_view).
    ShowCC = Bool
    # img_container: horizontal container holding `container` (the image
    # plus peak overlay) and, when the image has peaks, `colorbar`.
    img_container = Instance(Component)
    container = Instance(Component)
    colorbar= Instance(Component)
    # Peak counts over all images and for the current image.
    numpeaks_total = Int(0,cols=5)
    numpeaks_img = Int(0,cols=5)
    OK_custom=OK_custom_handler
    cbar_selection = Instance(RangeSelection)
    cbar_selected = Event
    # Colour-range selection that draw_colorbar re-applies to the scatter plot.
    thresh=Trait(None,None,List,Tuple,Array)
    thresh_upper=Float(1.0)
    thresh_lower=Float(0.0)
    # Number of images in the stack and index of the one being shown.
    numfiles=Int(1)
    img_idx=Int(0)
    tmp_img_idx=Int(0)

    # Cursor marking the template position on the main image (set in
    # render_image).
    csr=Instance(BaseCursorTool)

    # Dialog layout: image with navigation buttons on the left; template and
    # peak-finding controls on the right. OK is enabled only once
    # numpeaks_total > 0.
    traits_view = View(
        Group(
            # Left pane: the image view plus CC toggle and prev/next buttons.
            Group(
                Item("img_container",editor=ComponentEditor(), show_label=False),
                HGroup(
                    Item("ShowCC", editor=BooleanEditor(), label="Show cross correlation image"),
                    Spring(),
                    Item("prev_img",editor=ButtonEditor(label="<"),show_label=False, enabled_when='numfiles > 1'),
                    Item("next_img",editor=ButtonEditor(label=">"),show_label=False, enabled_when='numfiles > 1'),
                    ),
                label="Original image", show_border=True, trait_modified="tab_selected",
                orientation='vertical',),
            # Right pane: template position/size and peak parameters.
            VGroup(
                Group(
                    HGroup(
                        Item("left", label="Left coordinate", style="custom"),
                        Spring(),
                        Item("top", label="Top coordinate", style="custom"),
                        ),
                    Item("tmp_size", label="Template size", style="custom"),
                    Item("tmp_plot",editor=ComponentEditor(height=256, width=256), show_label=False, resizable=True),
                    label="Template", show_border=True),
                Group(
                    HGroup(
                        Item("peak_width", label="Peak width", style="custom"),
                        Spring(),
                        Item("findpeaks",editor=ButtonEditor(label="Find Peaks"),show_label=False),
                        ),
                    HGroup(
                        Item("thresh_lower",label="Threshold Lower Value", editor=TextEditor(evaluate=float,
                                                                                             format_str='%1.4f')),
                        Spring(),
                        Item("thresh_upper",label="Threshold Upper Value", editor=TextEditor(evaluate=float,
                                                                                             format_str='%1.4f')),
                    ),
                    HGroup(
                        Item("numpeaks_img",label="Number of Cells selected (this image)",style='readonly'),
                        Spring(),
                        Item("numpeaks_total",label="Total",style='readonly'),
                        ),
                    label="Peak parameters", show_border=True),
                ),
            orientation='horizontal'),
        buttons = [ Action(name='OK', enabled_when = 'numpeaks_total > 0' ),
            CancelButton ],
        title="Template Picker",
        handler=OK_custom, kind='livemodal',
        key_bindings = key_bindings,
        width=940, height=530,resizable=True)

    def __init__(self, signal_instance):
        """
        Set up the picker for ``signal_instance``.

        ``signal_instance.data`` is indexed as ``data[image, row, column]``.
        The template preview shows the ``tmp_size`` square at (``top``,
        ``left``) of image ``img_idx``; the main view shows the whole image.
        Image titles come from ``mapped_parameters.original_files`` when it
        exists, otherwise from ``mapped_parameters.title`` without its
        extension.

        OpenCV (``cv`` or ``cv2.cv``) is required. If neither can be imported,
        a message is printed and initialisation stops early, leaving the
        picker without its plots.
        """
        super(TemplatePicker, self).__init__()
        try:
            import cv  # only checks that OpenCV is available
        except ImportError:
            try:
                import cv2.cv as cv
            except ImportError:
                # Call form prints the same text on Python 2 and Python 3.
                print("OpenCV unavailable.  Can't do cross correlation without it.  Aborting.")
                return None
        self.OK_custom=OK_custom_handler()
        self.sig=signal_instance
        if not hasattr(self.sig.mapped_parameters,"original_files"):
            self.titles=[os.path.splitext(self.sig.mapped_parameters.title)[0]]
        else:
            original_files = self.sig.mapped_parameters.original_files
            self.numfiles=len(original_files.keys())
            # list() so titles can be indexed by image number on Python 3 too.
            self.titles=list(original_files.keys())
        tmp_plot_data=ArrayPlotData(imagedata=self.sig.data[self.img_idx,self.top:self.top+self.tmp_size,self.left:self.left+self.tmp_size])
        tmp_plot=Plot(tmp_plot_data,default_origin="top left")
        tmp_plot.img_plot("imagedata", colormap=jet)
        tmp_plot.aspect_ratio=1.0
        self.tmp_plot=tmp_plot
        self.tmp_plotdata=tmp_plot_data
        self.img_plotdata=ArrayPlotData(imagedata=self.sig.data[self.img_idx,:,:])
        self.img_container=self._image_plot_container()

        self.crop_sig=None

    def render_image(self):
        """
        Build the main image plot for image ``img_idx``.

        Shows ``img_plotdata['imagedata']`` in grey scale, titled
        "<n> of <total>: <title>". A white cursor that can be dragged with the
        left button is placed at (left, top). A right-button pan tool and a
        box-zoom overlay (given the image's aspect ratio) are attached. The
        cursor is stored in ``self.csr`` and the plot in ``self.img_plot``;
        the plot is returned.
        """
        image_plot = Plot(self.img_plotdata, default_origin="top left")
        renderer = image_plot.img_plot("imagedata", colormap=gray)[0]
        position = "%s of %s: " % (self.img_idx + 1, self.numfiles)
        image_plot.title = position + self.titles[self.img_idx]
        width = float(self.sig.data.shape[2])
        height = float(self.sig.data.shape[1])
        image_plot.aspect_ratio = width / height

        cursor = CursorTool(renderer, drag_button='left', color='white',
                            line_width=2.0)
        self.csr = cursor
        cursor.current_position = self.left, self.top
        renderer.overlays.append(cursor)

        # Navigation: right-drag pans, box zoom is toggled on demand.
        image_plot.tools.append(PanTool(image_plot, drag_button="right"))
        box_zoom = ZoomTool(image_plot, tool_mode="box", always_on=False,
                            aspect_ratio=image_plot.aspect_ratio)
        image_plot.overlays.append(box_zoom)
        self.img_plot = image_plot
        return image_plot

    def render_scatplot(self):
        """
        Build a colour-mapped scatter plot of the peaks of image ``img_idx``.

        Columns 0, 1 and 2 of ``self.peaks[img_idx]`` supply the "index",
        "value" and "color" arrays; colours are mapped with ``jet`` over
        [0, 1]. The renderer is named "my_plot", grids are hidden, and the
        plot shares ``img_plot``'s aspect ratio and 2D range so it lines up
        with the image. The plot is stored in ``self.scatplot`` (its data in
        ``self.peakdata``) and returned.
        """
        peaks = self.peaks[self.img_idx]
        peak_data = ArrayPlotData()
        for key, column in (("index", 0), ("value", 1), ("color", 2)):
            peak_data.set_data(key, peaks[:, column])

        scatter = Plot(peak_data, aspect_ratio=self.img_plot.aspect_ratio,
                       default_origin="top left")
        colour_map = jet(DataRange1D(low=0.0, high=1.0))
        scatter.plot(("index", "value", "color"),
                     type="cmap_scatter",
                     name="my_plot",
                     color_mapper=colour_map,
                     marker="circle",
                     fill_alpha=0.5,
                     marker_size=6)
        scatter.x_grid.visible = False
        scatter.y_grid.visible = False
        scatter.range2d = self.img_plot.range2d
        self.scatplot = scatter
        self.peakdata = peak_data
        return scatter

    def _image_plot_container(self):
        """
        Assemble the image view and return it.

        The image plot goes into an overlay container (``self.container``),
        which is placed in a white horizontal container
        (``self.img_container``). When the current image has peaks
        (``numpeaks_img > 0``), the peak scatter plot is overlaid on the image
        and a colour bar is added next to it.
        """
        image_plot = self.render_image()

        overlay = OverlayPlotContainer()
        self.container = overlay
        overlay.add(image_plot)

        outer = HPlotContainer(use_backbuffer=False)
        self.img_container = outer
        outer.add(overlay)
        outer.bgcolor = "white"

        if self.numpeaks_img > 0:
            overlay.add(self.render_scatplot())
            outer.add(self.draw_colorbar())
        return outer

    def draw_colorbar(self):
        """
        Build a vertical colour bar for the peak scatter plot and return it.

        A range-type ColormappedSelectionOverlay (fade_alpha=0.35) is added
        to the scatter renderer "my_plot". A RangeSelection tool and a
        RangeSelectionOverlay using the 'selections' metadata are added to
        the bar. If ``self.thresh`` is set, it is written back as the
        renderer's current 'selections'. The selection tool, renderer and bar
        are stored in ``self.cbar_selection``, ``self.cmap_renderer`` and
        ``self.colorbar``.
        """
        scatplot=self.scatplot
        cmap_renderer = scatplot.plots["my_plot"][0]
        selection = ColormappedSelectionOverlay(cmap_renderer, fade_alpha=0.35,
                                                selection_type="range")
        cmap_renderer.overlays.append(selection)
        if self.thresh is not None:
            # Restore the previously chosen colour range.
            cmap_renderer.color_data.metadata['selections']=self.thresh
            cmap_renderer.color_data.metadata_changed={'selections':self.thresh}
        # NOTE(review): no color_mapper is passed to the bar; Chaco's
        # cmap_scatter example passes cmap_renderer.color_mapper - confirm
        # whether the bar draws correctly without it.
        colorbar = ColorBar(index_mapper=LinearMapper(range=DataRange1D(low = 0.0,
                                       high = 1.0)),
                        orientation='v',
                        resizable='v',
                        width=30,
                        padding=20)
        colorbar_selection=RangeSelection(component=colorbar)
        colorbar.tools.append(colorbar_selection)
        colorbar.overlays.append(RangeSelectionOverlay(component=colorbar,
                                                       border_color="white",
                                                       alpha=0.8,
                                                       fill_color="lightgray",
                                                       metadata_name='selections'))
        self.cbar_selection=colorbar_selection
        self.cmap_renderer=cmap_renderer
        colorbar.plot = cmap_renderer
        # Match the scatter plot's vertical padding so the bar lines up.
        colorbar.padding_top = scatplot.padding_top
        colorbar.padding_bottom = scatplot.padding_bottom
        self.colorbar=colorbar
        return colorbar

    @on_trait_change('ShowCC')
    def toggle_cc_view(self):
        """Switch the main image between the raw frame and its cross-correlation map.

        When ``ShowCC`` is on, the correlation map is recomputed and the
        image grid is resized to its shape. Either way, the plots are
        redrawn afterwards.
        """
        if not self.ShowCC:
            # Back to the raw frame at the current image index.
            self.img_plotdata.set_data("imagedata", self.sig.data[self.img_idx, :, :])
        else:
            self.update_CC()
            n_rows, n_cols = self.CC.shape[0], self.CC.shape[1]
            source = self.img_plot.range2d.sources[0]
            source.set_data(np.arange(n_cols), np.arange(n_rows))
        self.redraw_plots()

    @on_trait_change("img_idx")
    def update_img_depth(self):
        """React to a new ``img_idx``: refresh the shown data, retitle, redraw.

        Shows either the cross-correlation map (``ShowCC``) or the raw frame.
        The title reads "<n> of <numfiles>: <title>".
        """
        if not self.ShowCC:
            self.img_plotdata.set_data("imagedata", self.sig.data[self.img_idx, :, :])
        else:
            self.update_CC()
        counter = "%s of %s: " % (self.img_idx + 1, self.numfiles)
        self.img_plot.title = counter + self.titles[self.img_idx]
        self.redraw_plots()

    def _get_max_pos_x(self):
        max_pos_x=self.sig.data.shape[-1]-self.tmp_size-1
        if max_pos_x>0:
            return max_pos_x
        else:
            return None

    def _get_max_pos_y(self):
        max_pos_y=self.sig.data.shape[-2]-self.tmp_size-1
        if max_pos_y>0:
            return max_pos_y
        else:
            return None

    @on_trait_change('next_img')
    def increase_img_idx(self, info):
        """Step to the next image, wrapping from the last one back to index 0."""
        at_last = self.img_idx == (self.numfiles - 1)
        self.img_idx = 0 if at_last else self.img_idx + 1

    @on_trait_change('prev_img')
    def decrease_img_idx(self, info):
        """Step to the previous image, wrapping from index 0 to the last one."""
        at_first = self.img_idx == 0
        self.img_idx = self.numfiles - 1 if at_first else self.img_idx - 1

    @on_trait_change('left, top')
    def update_csr_position(self):
        """Move the cursor to the template's (left, top) corner.

        The origin (0, 0) is skipped, mirroring update_top_left(), which
        ignores a cursor at the origin.
        """
        # Check both coordinates: testing only ``left > 0`` meant a template
        # on the left edge (left == 0) never moved the cursor when ``top``
        # changed.
        if self.left > 0 or self.top > 0:
            self.csr.current_position = self.left, self.top

    @on_trait_change('csr:current_position')
    def update_top_left(self):
        """Sync ``left``/``top`` from the cursor, respecting the max positions.

        Behaviour by cursor position:

        - Both coordinates <= 0: ignored.
        - x beyond ``max_pos_x``: only ``top`` follows, and only while y is
          below ``max_pos_y``. Otherwise the cursor is snapped to
          (``max_pos_x``, ``max_pos_y``); assigning the cursor notifies this
          handler again if the value changes.
        - y beyond ``max_pos_y``: ``top`` is clamped to ``max_pos_y``.
        - Otherwise: ``left``/``top`` take the cursor position unchanged.
        """
        pos = self.csr.current_position
        if not (pos[0] > 0 or pos[1] > 0):
            return
        if pos[0] > self.max_pos_x:
            if pos[1] < self.max_pos_y:
                self.top = pos[1]
            else:
                self.csr.current_position = self.max_pos_x, self.max_pos_y
        elif pos[1] > self.max_pos_y:
            self.left, self.top = pos[0], self.max_pos_y
        else:
            self.left, self.top = pos
        
    @on_trait_change('left, top, tmp_size')
    def update_tmp_plot(self):
        """Refresh the template preview from the current image window.

        Steps:

        - Shows the ``tmp_size`` x ``tmp_size`` window at (top, left) of
          image ``img_idx`` and resizes the preview grid to match.
        - Records ``img_idx`` as ``tmp_img_idx``.
        - If peaks have been counted (``numpeaks_total > 0``), replaces
          ``self.peaks`` with a single placeholder row ``[0, 0, -1]``.
        """
        size, row0, col0 = self.tmp_size, self.top, self.left
        window = self.sig.data[self.img_idx, row0:row0 + size, col0:col0 + size]
        self.tmp_plotdata.set_data("imagedata", window)
        source = self.tmp_plot.range2d.sources[0]
        source.set_data(np.arange(self.tmp_size), np.arange(self.tmp_size))
        self.tmp_img_idx = self.img_idx
        if self.numpeaks_total > 0:
            print("clearing peaks")
            self.peaks = [np.array([[0, 0, -1]])]

    @on_trait_change('left, top, tmp_size')
    def update_CC(self):
        """Recompute and display the cross-correlation map (only while ``ShowCC``).

        The template is the ``tmp_size`` window at (top, left) of image
        ``tmp_img_idx``. It is correlated (``cv_funcs.xcorr``) against image
        ``img_idx``. The result is stored in ``self.CC`` and shown as the
        main image.
        """
        if not self.ShowCC:
            return
        top, left, size = self.top, self.left, self.tmp_size
        template = self.sig.data[self.tmp_img_idx, top:top + size, left:left + size]
        self.CC = cv_funcs.xcorr(template, self.sig.data[self.img_idx, :, :])
        self.img_plotdata.set_data("imagedata", self.CC)

    @on_trait_change('cbar_selection:selection')
    def update_thresh(self):
        """Propagate the colorbar's range selection to the threshold and plots.

        Steps:

        - Stores the selection in ``self.thresh``, ``thresh_lower`` and
          ``thresh_upper``.
        - Pushes it into the colour-mapped renderer's ``selections`` metadata.
        - Requests a redraw of both containers.

        Best effort: it stops silently if any step fails. Possible causes
        include a cleared selection (presumably ``None``) or missing
        renderer/containers.
        """
        try:
            thresh = self.cbar_selection.selection
            self.thresh = thresh
            self.cmap_renderer.color_data.metadata['selections'] = thresh
            self.thresh_lower = thresh[0]
            self.thresh_upper = thresh[1]
            self.cmap_renderer.color_data.metadata_changed = {'selections': thresh}
            self.container.request_redraw()
            self.img_container.request_redraw()
        except Exception:
            # Deliberately best-effort UI callback. Unlike a bare ``except:``,
            # this still lets KeyboardInterrupt/SystemExit propagate.
            pass

    @on_trait_change('thresh_upper,thresh_lower')
    def manual_thresh_update(self):
        """Apply thresholds typed into the lower/upper fields to the peak overlay.

        Sets ``self.thresh`` to ``[thresh_lower, thresh_upper]``, writes it
        into the renderer's ``selections`` metadata and redraws both
        containers.
        """
        self.thresh = [self.thresh_lower, self.thresh_upper]
        color_data = self.cmap_renderer.color_data
        color_data.metadata['selections'] = self.thresh
        color_data.metadata_changed = {'selections': self.thresh}
        for target in (self.container, self.img_container):
            target.request_redraw()

    @on_trait_change('peaks,cbar_selection:selection,img_idx')
    def calc_numpeaks(self):
        """Count the peaks whose height falls inside the current threshold.

        Sets ``numpeaks_total`` (all images) and ``numpeaks_img`` (image
        ``img_idx``, or 0 if it has no peak entry). The threshold is the
        colorbar selection; if there is none, the range (0, 1) is used.
        Bounds are inclusive, following ``np.ma.masked_inside``.
        """
        try:
            thresh = self.cbar_selection.selection
            self.thresh = thresh
        except Exception:
            # No colorbar/selection available yet.
            thresh = None
        if thresh is None or len(thresh) == 0:
            thresh = (0, 1)
        low, high = thresh[0], thresh[1]

        def _count(pks):
            # masked_inside masks heights within [low, high], so the mask
            # marks exactly the peaks that pass the threshold.
            return np.sum(np.ma.masked_inside(pks[:, 2], low, high).mask)

        self.numpeaks_total = int(np.sum([_count(pks) for pks in self.peaks]))
        try:
            self.numpeaks_img = int(_count(self.peaks[self.img_idx]))
        except Exception:
            # e.g. img_idx has no entry in self.peaks yet.
            self.numpeaks_img = 0

    @on_trait_change('findpeaks')
    def locate_peaks(self):
        """Find template matches in every image and store them in ``self.peaks``.

        The template is the ``tmp_size`` x ``tmp_size`` window at (top, left)
        of image ``tmp_img_idx``. For each image:

        - The template is cross-correlated with the image
          (``cv_funcs.xcorr``); ``self.CC`` ends up holding the last map.
        - Peaks are located with hyperspy's ``two_dim_findpeaks``.
        - The third column (peak height) is rescaled back to correlation
          units.

        A progress dialog is shown meanwhile.
        """
        from hyperspy import peak_char as pc
        # The template is the same for every image: slice it once.
        template = self.sig.data[self.tmp_img_idx,
                                 self.top:self.top + self.tmp_size,
                                 self.left:self.left + self.tmp_size]
        progress = ProgressDialog(title="Peak finder progress", message="Finding peaks on %s images"%self.numfiles, max=self.numfiles, show_time=True, can_cancel=False)
        progress.open()
        peaks = []
        try:
            for idx in range(self.numfiles):
                self.CC = cv_funcs.xcorr(template, self.sig.data[idx, :, :])
                # peak finder needs peaks greater than 1.  Multiply by 255 to
                # scale them, then undo the scaling on the returned heights.
                pks = pc.two_dim_findpeaks(self.CC * 255, peak_width=self.peak_width,
                                           medfilt_radius=None)
                pks[:, 2] = pks[:, 2] / 255.
                peaks.append(pks)
                progress.update(idx + 1)
        except Exception:
            # Don't leave the dialog hanging on failure. NOTE(review): assumes
            # the dialog otherwise only closes itself once update() reaches
            # ``max`` -- confirm against the pyface backend in use.
            progress.close()
            raise
        self.peaks = peaks
        
    def mask_peaks(self, idx):
        """Return the peaks of image ``idx`` with out-of-threshold heights masked.

        Parameters
        ----------
        idx : int
            Index into ``self.peaks``.

        Returns
        -------
        numpy.ma.MaskedArray
            ``self.peaks[idx]`` as a masked array. The height column
            (column 2) is masked where it lies outside the colorbar selection
            (bounds inclusive). Pass the result to ``np.ma.compress_rows`` to
            drop the rejected peaks. With no colorbar, or an empty/``None``
            selection, the range (0, 1) is used, matching calc_numpeaks().
        """
        selector = getattr(self, "cbar_selection", None)
        thresh = getattr(selector, "selection", None)
        if thresh is None or len(thresh) == 0:
            thresh = (0, 1)
        mpeaks = np.ma.asarray(self.peaks[idx])
        mpeaks[:, 2] = np.ma.masked_outside(mpeaks[:, 2], thresh[0], thresh[1])
        return mpeaks

    @on_trait_change("peaks")
    def redraw_plots(self):
        """Rebuild the image plot, the peak scatter plot and the colorbar.

        Steps:

        - Always re-renders the image plot.
        - Removes any previous scatter plot and colorbar.
        - Adds new ones only if the current image has peaks inside the
          threshold (``numpeaks_img > 0``).
        - Requests a redraw of both containers.
        """
        self.container.remove(self.img_plot)
        newplot = self.render_image()
        self.container.add(newplot)
        self.img_plot = newplot

        # Remove the previous scatter plot and colorbar independently. They
        # may never have been created, or may have been removed already by
        # an earlier redraw. A failure removing one must not leave the other
        # behind (that would stack a second colorbar next to the new one).
        try:
            self.container.remove(self.scatplot)
        except Exception:
            pass
        try:
            self.img_container.remove(self.colorbar)
        except Exception:
            pass

        if self.numpeaks_img > 0:
            newscat = self.render_scatplot()
            self.container.add(newscat)
            # draw_colorbar() reads self.scatplot.
            self.scatplot = newscat
            colorbar = self.draw_colorbar()
            self.img_container.add(colorbar)
            self.colorbar = colorbar

        self.container.request_redraw()
        self.img_container.request_redraw()

    def crop_cells_stack(self):
        """Crop the matched cells of every image into ``self.crop_sig``.

        With a single image this is crop_cells() for image 0. With a stack,
        images with no peak inside the threshold are skipped, and the
        per-image results are combined into an ``AggregateCells`` signal.
        """
        from hyperspy.signals.aggregate import AggregateCells
        if self.numfiles == 1:
            self.crop_sig = self.crop_cells()
            return
        crop_agg = []
        for idx in range(self.numfiles):
            peaks = np.ma.compress_rows(self.mask_peaks(idx))
            # Test for remaining rows, not values: ``peaks.any()`` would also
            # skip an image whose only kept peak happens to be all zeros.
            if peaks.shape[0] > 0:
                crop_agg.append(self.crop_cells(idx))
        self.crop_sig = AggregateCells(*crop_agg)

    def crop_cells(self, idx=0):
        """Crop a ``tmp_size`` x ``tmp_size`` window at each kept peak of image ``idx``.

        Peaks outside the colorbar threshold are dropped (see mask_peaks()).
        Column 0 of a peak is used as the window's left edge and column 1 as
        its top edge.

        Returns
        -------
        Image
            hyperspy ``Image`` with data of shape (npeaks, tmp_size, tmp_size).
            Its ``mapped_parameters`` hold a title, the per-cell
            ``locations`` (filename, id, position) and the parent signal under
            ``original_files``.
        """
        print("cropping cells...")
        from hyperspy.signals.image import Image
        # filter the peaks that are outside the selected threshold
        peaks = np.ma.compress_rows(self.mask_peaks(idx))
        tmp_sz = self.tmp_size
        data = np.zeros((peaks.shape[0], tmp_sz, tmp_sz))
        if not hasattr(self.sig.mapped_parameters, "original_files"):
            parent = self.sig
        else:
            parent = self.sig.mapped_parameters.original_files[self.titles[idx]]
        pmp = parent.mapped_parameters
        positions = np.zeros((peaks.shape[0], 1),
                             dtype=[('filename', 'a256'), ('id', 'i4'),
                                    ('position', 'f4', (1, 2))])
        for i in range(peaks.shape[0]):
            # Peak coordinates may be floats; slice bounds must be integers.
            col, row = int(peaks[i, 0]), int(peaks[i, 1])
            # crop the cells from the given locations
            data[i, :, :] = self.sig.data[idx, row:row + tmp_sz, col:col + tmp_sz]
            positions[i] = (self.titles[idx], i, peaks[i, :2])
        # Build the signal once, after all cells are cropped. Building it
        # inside the loop left it unbound when no peak passed the threshold.
        crop_sig = Image({'data': data,
                          'mapped_parameters': {
                              'title': 'Cropped cells from %s' % self.titles[idx],
                              'record_by': 'image',
                              'locations': positions,
                              'original_files': {pmp.title: parent},
                          }
                          })
        return crop_sig
コード例 #11
0
ファイル: chaco_plot.py プロジェクト: conkiztador/pdviper
class StackedPlot(ChacoPlot):
    offset = Range(0.0, 1.0, 0.015)
    value_range = Range(0.01, 1.05, 1.00)
    flip_order = Bool(False)

    def _get_traits_group(self):
        """Return the view layout: the stacking controls above the plot component."""
        controls = HGroup(
            Item('flip_order'),
            Item('offset'),
            Item('value_range'),
        )
        plot_view = UItem('component', editor=ComponentEditor())
        return VGroup(controls, plot_view)

    def __init__(self):
        """Set up the overlay container, the two axes and an inspector tool.

        The Chaco plot and the shared mappers start as ``None``; _plot()
        creates them.
        """
        super(StackedPlot, self).__init__()
        self.container = OverlayPlotContainer(bgcolor='white',
                                              use_backbuffer=True,
                                              border_visible=True,
                                              padding=50,
                                              padding_left=110,
                                              fill_padding=True)
        self.data = ArrayPlotData()
        self.chaco_plot = None
        self.value_mapper = None
        self.index_mapper = None
        axis_fonts = dict(title_font=settings.axis_title_font,
                          tick_label_font=settings.tick_font)
        self.x_axis = PlotAxis(component=self.container,
                               orientation='bottom',
                               title=u'Angle (2\u0398)',
                               **axis_fonts)
        self.y_axis = PlotAxis(component=self.container,
                               orientation='left',
                               title='Normalized intensity (%s)' % get_value_scale_label('linear'),
                               **axis_fonts)
        self.container.overlays.extend([self.x_axis, self.y_axis])
        inspector = TraitsTool(self.container, classes=[LinePlot, PlotAxis])
        self.container.tools.append(inspector)
        self.colors = []
        self.last_flip_order = self.flip_order

    @on_trait_change('offset, value_range, flip_order')
    def _replot_data(self):
        """Redraw the stack using the current offset/range/order settings.

        Reuses the x, z and scale stored by the last _plot() call. Does
        nothing if nothing has been plotted yet.
        """
        if getattr(self, 'data_z', None) is None:
            # The controls can change before any data is loaded, when _plot()
            # has not yet stored data_x/data_z/scale.
            return
        self._plot(self.data_x, None, self.data_z, self.scale)
        self.container.request_redraw()

    def _prepare_data(self, datasets):
        """Turn ``datasets`` into per-dataset x and z rows for _plot().

        With interpolation (the only path currently enabled), the stack goes
        through ``interpolate_datasets(..., points=4800)``. The returned x
        row is repeated once per dataset. Returns ``(x, None, z)``.
        """
        use_interpolation = True
        stacked = stack_datasets(datasets)
        if not use_interpolation:
            x, z = map(np.transpose, np.transpose(stacked))
            return x, None, z
        common_x, z = interpolate_datasets(stacked, points=4800)
        return array([common_x] * len(datasets)), None, z

    def _plot(self, x, y, z, scale):
        """Draw each row of ``z`` against ``x`` as a vertically offset line.

        Parameters
        ----------
        x, z :
            Row-per-dataset arrays (as produced by _prepare_data()).
        y :
            Unused; kept for interface compatibility.
        scale :
            Value-scale identifier passed to get_value_scale_label() for the
            y-axis title.

        Stores x, z and scale for _replot_data(). Reuses the previous line
        colours when the number of datasets is unchanged (reversed if
        ``flip_order`` changed). All lines share the first line's mappers.
        Returns ``self.container``.
        """
        self.data_x, self.data_z, self.scale = x, z, scale
        if self.container.components:
            # A real list (not a lazy ``map``) is needed for reverse() and
            # the slice copy below on Python 3.
            self.colors = [plot.color for plot in self.container.components]
            if self.last_flip_order != self.flip_order:
                self.colors.reverse()
            self.container.remove(*self.container.components)
        # Use a custom renderer so plot lines are clickable
        self.chaco_plot = Plot(self.data,
                               renderer_map={'line': ClickableLinePlot})
        self.chaco_plot.bgcolor = 'white'
        self.value_mapper = None
        self.index_mapper = None

        if len(x) == 0 or len(z) == 0:
            # Nothing to draw; the reductions and mapper access below would fail.
            self.last_flip_order = self.flip_order
            return self.container

        if len(self.data_x) == len(self.colors):
            colors = self.colors[:]
        else:
            colors = ['black'] * len(self.data_x)

        if self.flip_order:
            z = z[::-1]

        spacing = (z.max(axis=1) - z.min(axis=1)).min() * self.value_range
        offset = spacing * self.offset
        for i, (x_row, z_row) in enumerate(zip(x, z)):
            x_name = 'data_x_' + str(i)
            y_name = 'data_y_offset_' + str(i)
            self.data.set_data(x_name, x_row)
            self.data.set_data(y_name, z_row * self.value_range + offset * i)
            plots = self.chaco_plot.plot((x_name, y_name), color=colors[i], type='line')
            plot = plots[0]
            self.container.add(plot)
            # Required for double-clicking plots
            plot.index.sort_order = 'ascending'
            plot.value.sort_order = 'ascending'

            if self.value_mapper is None:
                # The first line's mappers become the shared ones.
                self.index_mapper = plot.index_mapper
                self.value_mapper = plot.value_mapper
            else:
                # Later lines reuse them; their ranges grow to cover each line.
                plot.value_mapper = self.value_mapper
                self.value_mapper.range.add(plot.value)
                plot.index_mapper = self.index_mapper
                self.index_mapper.range.add(plot.index)
        # Scale the shared value range (named ``vrange`` so the ``range``
        # builtin is not shadowed).
        vrange = self.value_mapper.range
        vrange.high = (vrange.high - vrange.low) * self.value_range + vrange.low
        self.x_axis.mapper = self.index_mapper
        self.y_axis.mapper = self.value_mapper
        self.y_axis.title = 'Normalized intensity (%s)' % \
                get_value_scale_label(scale)
        # Box zoom is attached to the last line, which shares its mappers
        # with all the others.
        self.zoom_tool = ClickUndoZoomTool(
            plot, tool_mode="box", always_on=True, pointer="cross",
            drag_button=settings.zoom_button,
            undo_button=settings.undo_button,
        )
        plot.overlays.append(self.zoom_tool)
        self.last_flip_order = self.flip_order
        return self.container

    def _reset_view(self):
        self.zoom_tool.revert_history_all()