class SelectFromCollection(object):
    """Lasso/keyboard/slider front-end for an interactive RLS classifier.

    Points are drawn as a grey-scale scatter plot (``img[:, 0]`` against
    ``img[:, 1]``) coloured by ``mmc.classvec``. A lasso selection defines
    the classifier's working set. Key presses relabel, lock/unlock or
    optimise the working set. A slider claims a number of working-set points
    for class 1 or class 0.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes hosting the scatter plot and the lasso.
    collection : array-like
        Point coordinates tested against the lasso path (passed to
        ``Path.contains_points``); presumably shape (n_points, 2).
    mmc : object
        Interactive RLS classifier. It is used through ``classvec``,
        ``classvec_ws``, ``working_set``, ``new_working_set``,
        ``claim_all_points_in_working_set``, ``claim_n_points``,
        ``cyclic_descent_in_working_set`` and ``compute_steepness_vector``.
    img : numpy.ndarray
        Array whose first two columns are plotted as scatter x / y.

    Notes
    -----
    Relies on module-level ``plt``, ``np``, ``Normalize``, ``LassoSelector``,
    ``Path``, ``Slider``, ``slider_coords``, ``obj_fun_display_coords`` and
    ``print_instructions``.

    NOTE(review): a second ``SelectFromCollection`` class defined later in
    this module replaces this one at import time.
    """

    def __init__(self, ax, collection, mmc, img):
        self.colornormalizer = Normalize(vmin=0, vmax=1, clip=False)
        # Grey-scale scatter of the points, coloured by the current labels.
        self.scat = plt.scatter(img[:, 0], img[:, 1], c=mmc.classvec)
        plt.gray()
        plt.setp(ax.get_yticklabels(), visible=False)
        ax.yaxis.set_tick_params(size=0)
        plt.setp(ax.get_xticklabels(), visible=False)
        ax.xaxis.set_tick_params(size=0)
        self.img = img
        self.canvas = ax.figure.canvas
        self.collection = collection
        self.mmc = mmc
        self.prevnewclazz = None

        self.xys = collection
        self.Npts = len(self.xys)

        self.lockedset = set()

        # Replace the LassoSelector's default handlers so that mouse release
        # goes through self.onrelease and key presses reach self.onkeypressed.
        self.lasso = LassoSelector(ax, onselect=self.onselect)
        self.lasso.disconnect_events()
        self.lasso.connect_event('button_press_event', self.lasso.onpress)
        self.lasso.connect_event('button_release_event', self.onrelease)
        self.lasso.connect_event('motion_notify_event', self.lasso.onmove)
        self.lasso.connect_event('draw_event', self.lasso.update_background)
        self.lasso.connect_event('key_press_event', self.onkeypressed)
        self.ind = []
        # Invisible placeholder axes; redrawall() replaces them.
        self.slider_axis = plt.axes(slider_coords, visible = False)
        self.slider_axis2 = plt.axes(obj_fun_display_coords, visible = False)
        self.in_selection_slider = None
        # Start with every unlocked point in the working set.
        newws = list(set(range(len(self.collection))) - self.lockedset)
        self.mmc.new_working_set(newws)
        self.lasso.line.set_visible(False)

    def onselect(self, verts):
        """Use the unlocked points inside the lasso polygon as the working set.

        Parameters
        ----------
        verts : sequence of (x, y) pairs
            Vertices of the lasso polygon.
        """
        self.path = Path(verts)
        self.ind = np.nonzero(self.path.contains_points(self.xys))[0]
        print('Selected ' + str(len(self.ind)) + ' points')
        newws = list(set(self.ind) - self.lockedset)
        self.mmc.new_working_set(newws)
        self.redrawall()

    def onpress(self, event):
        """Start a new lasso path at the pressed position.

        ``__init__`` does not connect this method; it connects
        ``self.lasso.onpress`` instead. This method is kept as an alternative
        press handler.
        """
        # The axes belong to the lasso; this class stores no ``ax`` attribute.
        if self.lasso.ignore(event) or event.inaxes != self.lasso.ax:
            return
        self.lasso.line.set_data([[], []])
        self.lasso.verts = [(event.xdata, event.ydata)]
        self.lasso.line.set_visible(True)

    def onrelease(self, event):
        """Close the lasso path and pass its vertices to the lasso's onselect."""
        if self.lasso.ignore(event):
            return
        if self.lasso.verts is not None:
            # Add the release point only if it lies inside the lasso's axes.
            if event.inaxes == self.lasso.ax:
                self.lasso.verts.append((event.xdata, event.ydata))
            self.lasso.onselect(self.lasso.verts)
        self.lasso.verts = None

    def onkeypressed(self, event):
        """Handle keyboard commands, then redraw.

        '1' / '0'  assign every working-set point to class 1 / class 0
        'a'        make all unlocked points the working set
        'c'        run cyclic descent on the working set
        'l'        lock the selected points; the working set becomes the
                   selected points minus the locked ones (i.e. empty)
        'u'        unlock the selected points and make them the working set
        """
        mmc = self.mmc
        key = event.key
        print('You pressed', key)
        if key == '1':
            print('Assigned all selected points to class 1')
            mmc.claim_all_points_in_working_set(1)
        elif key == '0':
            print('Assigned all selected points to class 0')
            mmc.claim_all_points_in_working_set(0)
        elif key == 'a':
            print('Selected all points')
            newws = list(set(range(len(self.collection))) - self.lockedset)
            mmc.new_working_set(newws)
            self.lasso.line.set_visible(False)
        elif key == 'c':
            changecount = mmc.cyclic_descent_in_working_set()
            print('Performed ', changecount, 'cyclic descent steps')
        elif key == 'l':
            print('Locked the class labels of selected points')
            self.lockedset = self.lockedset | set(self.ind)
            mmc.new_working_set(list(set(self.ind) - self.lockedset))
        elif key == 'u':
            print('Unlocked the selected points')
            self.lockedset = self.lockedset - set(self.ind)
            mmc.new_working_set(list(set(self.ind) - self.lockedset))
        self.redrawall()

    def _show_steepness(self):
        """Draw ``mmc.compute_steepness_vector()`` as a two-row strip in slider_axis2."""
        steepness_vector = self.mmc.compute_steepness_vector()
        X = [steepness_vector, steepness_vector]
        self.slider_axis2.imshow(X, cmap=plt.get_cmap("Oranges"))
        self.slider_axis2.set_aspect('auto')

    def redrawall(self, newslider = True):
        """Recolour the scatter plot and redraw, optionally rebuilding the slider.

        Parameters
        ----------
        newslider : bool, optional
            When True (the default), the slider and the steepness display
            are rebuilt on fresh axes, so the slider range
            ``[0, len(mmc.working_set)]`` matches the current working set.
            The slider's own callback passes False, so it does not replace
            the slider that is firing.
        """
        mmc = self.mmc
        if newslider:
            nozeros = np.nonzero(mmc.classvec_ws)[0]
            # Dropping the Python references is not enough. The figure keeps
            # the old axes and the old slider stays connected, so remove them
            # explicitly before creating replacements.
            if self.in_selection_slider is not None:
                self.in_selection_slider.disconnect_events()
            self.slider_axis.remove()
            self.slider_axis2.remove()
            self.slider_axis = plt.axes(slider_coords)
            self.slider_axis2 = plt.axes(obj_fun_display_coords)
            self._show_steepness()
            plt.setp(self.slider_axis2.get_yticklabels(), visible=False)
            self.slider_axis2.yaxis.set_tick_params(size=0)
            self.in_selection_slider = Slider(self.slider_axis, 'Fraction slider', 0., len(mmc.working_set), valinit=len(nozeros))

            def sliderupdate(val):
                # The slider value is the desired number of nonzero-labelled
                # working-set points; claim the difference for class 1 or 0.
                val = int(val)
                nonzeroc = len(np.nonzero(self.mmc.classvec_ws)[0])
                if val > nonzeroc:
                    claims = val - nonzeroc
                    newclazz = 1
                elif val < nonzeroc:
                    claims = nonzeroc - val
                    newclazz = 0
                else:
                    return
                print('Claimed', claims, 'points for class', newclazz)
                self.claims = claims
                self.mmc.claim_n_points(claims, newclazz)
                self._show_steepness()
                # Recolour without rebuilding the slider that is firing.
                self.redrawall(newslider = False)
                self.prevnewclazz = newclazz
            self.in_selection_slider.on_changed(sliderupdate)

        # The grey level of each point follows its label in mmc.classvec.
        scatcolors = self.scat.get_facecolors()
        for channel in range(3):
            scatcolors[:, channel] = mmc.classvec
        self.scat.set_facecolors(scatcolors)

        if self.lasso.useblit:
            self.lasso.canvas.restore_region(self.lasso.background)
            self.lasso.ax.draw_artist(self.lasso.line)
            self.lasso.canvas.blit(self.lasso.ax.bbox)
        else:
            self.lasso.canvas.draw_idle()
        plt.draw()
        print_instructions()

    def disconnect(self):
        """Disconnect the lasso's event handlers and request a redraw."""
        self.lasso.disconnect_events()
        self.canvas.draw_idle()
class SelectFromCollection(object):
    """Interactive RLS classifier interface for image segmentation.

    A lasso drawn on the image selects the classifier's working set. Key
    presses and a fraction slider change the labels of the working set.
    Pixels labelled class 1 are painted red; class-0 pixels show the
    original image.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        The Figure object on which the interface is drawn.

    mmc : rlscore.learner.interactive_rls_classifier.InteractiveRlsClassifier
        Interactive RLS classifier object.

    img : numpy.array
        Array of image data, indexed as ``img[row, col, channel]``. It is
        modified in place to display the labelling.

    collection : numpy.array, shape = [n_pixels, 2]
        Array of the (x, y) coordinates of all usable pixels in the image.

    windowsize : int
        Determines the size of a window around grid points
        (2 * windowsize + 1). Window pixels outside the image are skipped.

    Notes
    -----
    Relies on module-level ``plt``, ``np``, ``LassoSelector``, ``Path``,
    ``Slider`` and ``print_instructions``. The 'p' key also uses the
    module-level ``Xmat`` and ``auc``.
    """

    def __init__(self, fig, mmc, img, collection, windowsize = 0):

        # Initialize the main axis (image without tick labels)
        ax = fig.add_axes([0.1, 0.1, 0.8, 0.8])
        ax.set_yticklabels([])
        ax.yaxis.set_tick_params(size = 0)
        ax.set_xticklabels([])
        ax.xaxis.set_tick_params(size = 0)
        self.imdata = ax.imshow(img)

        # Initialize LassoSelector on the main axis; key presses go to onkeypressed
        self.lasso = LassoSelector(ax, onselect = self.onselect)
        self.lasso.connect_event('key_press_event', self.onkeypressed)
        self.lasso.line.set_visible(False)

        self.mmc = mmc
        self.img = img
        self.img_orig = img.copy()
        self.collection = collection
        self.selectedset = set()
        self.lockedset = set()
        self.windowsize = windowsize

        # Initialize the fraction slider: fraction of working-set points with a nonzero label
        self.slider_axis = fig.add_axes([0.2, 0.06, 0.6, 0.02])
        self.in_selection_slider = Slider(self.slider_axis,
                                          'Fraction slider',
                                          0.,
                                          1,
                                          valinit = self._labeled_fraction())
        def sliderupdate(val):
            val = self._fraction_to_count(val, len(self.mmc.working_set))
            nonzeroc = len(np.nonzero(self.mmc.classvec_ws)[0])
            if val > nonzeroc:
                claims = val - nonzeroc
                newclazz = 1
            elif val < nonzeroc:
                claims = nonzeroc - val
                newclazz = 0
            else: return
            print('Claimed', claims, 'points for class', newclazz)
            self.claims = claims
            self.mmc.claim_n_points(claims, newclazz)
            self.redrawall()
        self.in_selection_slider.on_changed(sliderupdate)

        # Initialize the display for the RLS objective function
        self.objfun_display_axis = fig.add_axes([0.1, 0.96, 0.8, 0.02])
        self._objfun_image = None
        self._draw_objfun()
        self.objfun_display_axis.set_yticklabels([])
        self.objfun_display_axis.yaxis.set_tick_params(size = 0)

    @staticmethod
    def _fraction_to_count(fraction, n):
        """Convert a slider fraction of an ``n``-point working set to a point count.

        Uses round() rather than int(). ``count / n * n`` can land just
        below the integer it encodes (e.g. ``1/49*49 == 0.999...``), and
        truncating would claim a spurious point whenever redrawall() pushes
        the current fraction back into the slider.
        """
        return int(round(fraction * n))

    def _labeled_fraction(self):
        """Return the fraction of working-set points with a nonzero label (0.0 if the set is empty)."""
        n_ws = len(self.mmc.working_set)
        if n_ws == 0:
            return 0.0
        # float() keeps true division even under Python 2 semantics
        return float(len(np.nonzero(self.mmc.classvec_ws)[0])) / n_ws

    def _window_pixels(self, idx):
        """Return (rows, cols) of the in-image pixels in the windows around ``collection[idx]``.

        Each window is (2 * windowsize + 1) pixels square. Offsets that fall
        outside ``img`` are dropped: negative indices would otherwise wrap
        to the opposite edge, and too-large ones would raise IndexError.
        """
        col_row = self.collection[idx]
        rowcs, colcs = col_row[:, 1], col_row[:, 0]
        n_rows, n_cols = self.img.shape[0], self.img.shape[1]
        rows, cols = [], []
        for i in range(-self.windowsize, self.windowsize + 1):
            for j in range(-self.windowsize, self.windowsize + 1):
                r = rowcs + i
                c = colcs + j
                inside = (r >= 0) & (r < n_rows) & (c >= 0) & (c < n_cols)
                rows.append(r[inside])
                cols.append(c[inside])
        if not rows:
            empty = np.empty(0, dtype=int)
            return empty, empty
        return np.concatenate(rows), np.concatenate(cols)

    def _draw_objfun(self):
        """Show the steepness vector as a one-row image, replacing the previous image."""
        if getattr(self, '_objfun_image', None) is not None:
            # Without this, every redraw would stack one more image on the axis.
            self._objfun_image.remove()
        self._objfun_image = self.objfun_display_axis.imshow(
            self.mmc.compute_steepness_vector()[np.newaxis, :],
            cmap = plt.get_cmap("Oranges"))
        self.objfun_display_axis.set_aspect('auto')

    def onselect(self, verts):
        """Select the points inside the lasso; their unlocked part becomes the working set."""
        self.path = Path(verts)
        self.selectedset = set(np.nonzero(self.path.contains_points(self.collection))[0])
        print('Selected ' + str(len(self.selectedset)) + ' points')
        newws = list(self.selectedset - self.lockedset)
        self.mmc.new_working_set(newws)
        self.redrawall()

    def onkeypressed(self, event):
        """Handle keyboard commands, then redraw.

        '1' / '0'  assign every working-set point to class 1 / class 0
        'a'        make all unlocked points the working set
        'c'        run cyclic descent on the working set
        'l'        lock the selected points; the working set becomes the
                   selected points minus the locked ones (i.e. empty)
        'u'        unlock the selected points and make them the working set
        'p'        predict on module-level ``Xmat`` and print the AUC
        """
        mmc = self.mmc
        key = event.key
        print('You pressed', key)
        if key == '1':
            print('Assigned all selected points to class 1')
            mmc.claim_all_points_in_working_set(1)
        elif key == '0':
            print('Assigned all selected points to class 0')
            mmc.claim_all_points_in_working_set(0)
        elif key == 'a':
            print('Selected all points')
            newws = list(set(range(len(self.collection))) - self.lockedset)
            mmc.new_working_set(newws)
            self.lasso.line.set_visible(False)
        elif key == 'c':
            changecount = mmc.cyclic_descent_in_working_set()
            print('Performed ', changecount, 'cyclic descent steps')
        elif key == 'l':
            print('Locked the class labels of selected points')
            self.lockedset = self.lockedset | self.selectedset
            mmc.new_working_set(list(self.selectedset - self.lockedset))
        elif key == 'u':
            print('Unlocked the selected points')
            self.lockedset = self.lockedset - self.selectedset
            mmc.new_working_set(list(self.selectedset - self.lockedset))
        elif key == 'p':
            print('Compute predictions and AUC on data')
            # NOTE(review): Xmat and auc are module-level names not defined in this class
            preds = mmc.predict(Xmat)
            print(auc(mmc.Y[:, 0], preds[:, 0]))
        self.redrawall()

    def redrawall(self):
        """Repaint the labelling, sync the slider and objective display, and redraw."""
        # Color all class one labeled pixels red
        rows, cols = self._window_pixels(np.nonzero(self.mmc.classvec)[0])
        self.img[rows, cols, :] = np.array([255, 0, 0])

        # Return the original color of the class zero labeled pixels
        rows, cols = self._window_pixels(np.nonzero(self.mmc.classvec - 1)[0])
        self.img[rows, cols, :] = self.img_orig[rows, cols, :]
        self.imdata.set_data(self.img)

        # Update the slider position according to labeling of the current working set
        self.in_selection_slider.set_val(self._labeled_fraction())

        # Update the RLS objective function display
        self._draw_objfun()

        # Final stuff
        self.lasso.canvas.draw_idle()
        plt.draw()
        print_instructions()