コード例 #1
0
class SelectFromCollection(object):
    """Interactive lasso/keyboard labelling of scatter points backed by ``mmc``.

    Draws ``img`` as a gray-level scatter plot (colour taken from
    ``mmc.classvec``) and wires a ``LassoSelector`` plus key bindings to the
    model object ``mmc``:

    * lasso release -- make the lassoed, unlocked points the working set
    * ``'1'`` / ``'0'`` -- assign every working-set point to class 1 / 0
    * ``'a'`` -- make every unlocked point the working set
    * ``'c'`` -- run ``mmc.cyclic_descent_in_working_set()``
    * ``'l'`` / ``'u'`` -- lock / unlock the points of the last lasso selection

    Locked points are excluded from every working set. A ``Slider`` placed
    at ``slider_coords`` claims points for class 1 or 0, and a strip at
    ``obj_fun_display_coords`` shows ``mmc.compute_steepness_vector()``.

    Parameters
    ----------
    ax : matplotlib Axes holding the scatter plot and the lasso.
    collection : point coordinates, indexed by row and passed to
        ``Path.contains_points`` -- presumably an (N, 2) array; confirm.
    mmc : model exposing ``classvec``, ``classvec_ws``, ``working_set``,
        ``new_working_set``, ``claim_n_points``,
        ``claim_all_points_in_working_set``,
        ``cyclic_descent_in_working_set`` and ``compute_steepness_vector``.
    img : array whose columns 0 and 1 are used as the scatter x / y.
    """

    def __init__(self, ax, collection, mmc, img):
        self.ax = ax  # read by onpress()
        self.colornormalizer = Normalize(vmin=0, vmax=1, clip=False)
        self.scat = plt.scatter(img[:, 0], img[:, 1], c=mmc.classvec)
        plt.gray()
        # Hide tick labels and tick marks on both axes of the scatter plot.
        plt.setp(ax.get_yticklabels(), visible=False)
        ax.yaxis.set_tick_params(size=0)
        plt.setp(ax.get_xticklabels(), visible=False)
        ax.xaxis.set_tick_params(size=0)
        self.img = img
        self.canvas = ax.figure.canvas
        self.collection = collection
        self.mmc = mmc
        self.prevnewclazz = None

        self.xys = collection
        self.Npts = len(self.xys)

        # Indices whose labels are frozen; excluded from every working set.
        self.lockedset = set()

        self.lasso = LassoSelector(ax, onselect=self.onselect)
        # Replace the selector's default event wiring so that button release
        # and key presses are routed through this class's handlers.
        self.lasso.disconnect_events()
        self.lasso.connect_event('button_press_event', self.lasso.onpress)
        self.lasso.connect_event('button_release_event', self.onrelease)
        self.lasso.connect_event('motion_notify_event', self.lasso.onmove)
        self.lasso.connect_event('draw_event', self.lasso.update_background)
        self.lasso.connect_event('key_press_event', self.onkeypressed)
        self.ind = []
        # Invisible placeholder axes; redrawall() replaces them.
        self.slider_axis = plt.axes(slider_coords, visible=False)
        self.slider_axis2 = plt.axes(obj_fun_display_coords, visible=False)
        self.in_selection_slider = None
        # Start with every (unlocked) point in the working set.
        self._set_working_set(range(len(self.collection)))
        self.lasso.line.set_visible(False)

    def _set_working_set(self, indices):
        """Make the unlocked members of ``indices`` the model's working set."""
        self.mmc.new_working_set(list(set(indices) - self.lockedset))

    def onselect(self, verts):
        """Lasso callback: select the points inside ``verts`` and redraw."""
        self.path = Path(verts)
        self.ind = np.nonzero(self.path.contains_points(self.xys))[0]
        print('Selected %d points' % len(self.ind))
        self._set_working_set(self.ind)
        self.redrawall()

    def onpress(self, event):
        """Start a new lasso stroke at the press position.

        Note: ``__init__`` connects ``self.lasso.onpress``, not this method.
        """
        if self.lasso.ignore(event) or event.inaxes != self.ax:
            return
        self.lasso.line.set_data([[], []])
        self.lasso.verts = [(event.xdata, event.ydata)]
        self.lasso.line.set_visible(True)

    def onrelease(self, event):
        """Close the lasso stroke and hand its vertices to ``onselect``."""
        if self.lasso.ignore(event):
            return
        if self.lasso.verts is not None:
            # Only add the release point when it lies inside the lasso axes.
            if event.inaxes == self.lasso.ax:
                self.lasso.verts.append((event.xdata, event.ydata))
            self.lasso.onselect(self.lasso.verts)
        self.lasso.verts = None

    def onkeypressed(self, event):
        """Dispatch the single-key commands listed in the class docstring."""
        key = event.key
        print('You pressed %s' % (key,))
        if key == '1':
            print('Assigned all selected points to class 1')
            self.mmc.claim_all_points_in_working_set(1)
        elif key == '0':
            print('Assigned all selected points to class 0')
            self.mmc.claim_all_points_in_working_set(0)
        elif key == 'a':
            print('Selected all points')
            # NOTE(review): self.ind is not updated here, so 'l'/'u' still act
            # on the last lasso selection -- confirm this is intended.
            self._set_working_set(range(len(self.collection)))
            self.lasso.line.set_visible(False)
        elif key == 'c':
            changecount = self.mmc.cyclic_descent_in_working_set()
            print('Performed %s cyclic descent steps' % (changecount,))
        elif key == 'l':
            print('Locked the class labels of selected points')
            self.lockedset = self.lockedset | set(self.ind)
            # Every selected point is now locked, so this set is empty.
            self._set_working_set(self.ind)
        elif key == 'u':
            print('Unlocked the selected points')
            self.lockedset = self.lockedset - set(self.ind)
            self._set_working_set(self.ind)
        self.redrawall()

    def _draw_steepness(self):
        """Redraw ``mmc.compute_steepness_vector()`` as a strip in slider_axis2."""
        steepness_vector = self.mmc.compute_steepness_vector()
        ax2 = self.slider_axis2
        # Clear first so repeated redraws replace the image instead of
        # stacking a new one on top of the old.
        ax2.cla()
        ax2.imshow([steepness_vector, steepness_vector],
                   cmap=plt.get_cmap("Oranges"))
        ax2.set_aspect('auto')
        plt.setp(ax2.get_yticklabels(), visible=False)
        ax2.yaxis.set_tick_params(size=0)

    def _rebuild_slider(self):
        """Replace the slider/steepness axes and attach a fresh ``Slider``."""
        nonzero_count = len(np.nonzero(self.mmc.classvec_ws)[0])
        # Deleting the Python references alone leaves the old axes (and the
        # old widget's canvas callbacks) in the figure, so they piled up on
        # every redraw; disconnect and remove them explicitly.
        if self.in_selection_slider is not None:
            self.in_selection_slider.disconnect_events()
        for old_ax in (self.slider_axis, self.slider_axis2):
            old_ax.figure.delaxes(old_ax)
        self.slider_axis = plt.axes(slider_coords)
        self.slider_axis2 = plt.axes(obj_fun_display_coords)
        self._draw_steepness()
        self.in_selection_slider = Slider(
            self.slider_axis, 'Fraction slider', 0.,
            len(self.mmc.working_set), valinit=nonzero_count)

        def sliderupdate(val):
            # Claim |val - nonzeroc| points: for class 1 when the slider is
            # above the current nonzero count of mmc.classvec_ws, for class 0
            # when it is below.
            val = int(val)
            nonzeroc = len(np.nonzero(self.mmc.classvec_ws)[0])
            if val == nonzeroc:
                return
            if val > nonzeroc:
                claims, newclazz = val - nonzeroc, 1
            else:
                claims, newclazz = nonzeroc - val, 0
            print('Claimed %d points for class %d' % (claims, newclazz))
            self.claims = claims
            self.mmc.claim_n_points(claims, newclazz)
            self._draw_steepness()
            # Refresh colours only: rebuilding the slider here would replace
            # the widget whose callback is currently running.
            self.redrawall(newslider=False)
            self.prevnewclazz = newclazz

        self.in_selection_slider.on_changed(sliderupdate)

    def _recolor_points(self):
        """Set each point's RGB channels to its ``mmc.classvec`` value."""
        scatcolors = self.scat.get_facecolors()
        for channel in range(3):
            scatcolors[:, channel] = self.mmc.classvec
        self.scat.set_facecolors(scatcolors)

    def redrawall(self, newslider=True):
        """Recolour the scatter points and redraw the figure.

        Parameters
        ----------
        newslider : bool
            When true, first rebuild the slider and steepness axes with a
            fresh ``Slider`` widget.
        """
        if newslider:
            self._rebuild_slider()
        self._recolor_points()

        if self.lasso.useblit:
            self.lasso.canvas.restore_region(self.lasso.background)
            self.lasso.ax.draw_artist(self.lasso.line)
            self.lasso.canvas.blit(self.lasso.ax.bbox)
        else:
            self.lasso.canvas.draw_idle()
        plt.draw()
        print_instructions()

    def disconnect(self):
        """Detach the lasso's event handlers and request a redraw."""
        self.lasso.disconnect_events()
        self.canvas.draw_idle()