Ejemplo n.º 1
1
class ImagePanel(tk.Frame):
    """Tk frame embedding a matplotlib canvas that shows one slice of an image stack.

    Optional extras (icon toolbar, window/level sliders, info frame, colorbar)
    are switched on through constructor flags.  The panel implements the
    mediator pattern: every registered colleague is notified of user actions
    through its ``update_(name, event)`` method.
    """

    # Number of pixels cropped from every image border per zoom step.
    _ZOOM_STEP = 10

    def __init__(self, parent, name, figsize=(2,2), dpi=100, title='', wl=False, info=False, toolbar=False, cb=False):
        """Create the panel and build its widgets.

        parent  -- Tk parent widget
        name    -- identifier handed to colleagues with every notification
        figsize -- matplotlib figure size
        dpi     -- matplotlib figure resolution
        title   -- title of the image axes
        wl      -- when true, add the window/level sliders
        info    -- when true, add an 'Image Info' label frame to the parent
        toolbar -- when true, add the icon toolbar
        cb      -- when true, draw a colorbar for the image
        """
        self.name = name
        tk.Frame.__init__(self, parent)
        self.parent = parent
        self.img_is_set = False
        self.figsize = figsize
        self.dpi = dpi
        self.title = title
        self.toolbar = toolbar
        self.wl = wl
        self.cb = cb

        # Image / display state.
        self.images = None
        self.original_img = None
        self.nZoom = 0
        self.indexOfImg = 0
        self.level = None
        self.window = None
        self.level_is_set = False
        self.window_is_set = False
        self.initial_level = None
        self.initial_window = None
        self.img = None
        self.cbar = None
        self.metadata = {}
        self.collegues = []

        # Rectangle-selection state; defined up front so that the toolbar
        # actions are safe to trigger in any order.
        self.rectangle = None
        self.ispressed = False
        self.bpe = self.bre = self.mne = None
        self.x0 = self.y0 = self.x1 = self.y1 = None
        self.xp0 = self.yp0 = self.xp1 = self.yp1 = None

        self.make_canvas()
        if info:
            self.make_info()

    # ------------------------------------------------------------------
    # Mediator plumbing
    # ------------------------------------------------------------------
    def register(self, collegue):
        """Add a colleague that will be notified by ``inform``."""
        self.collegues.append(collegue)

    def inform(self, event):
        """Call ``update_(self.name, event)`` on every registered colleague."""
        for collegue in self.collegues:
            collegue.update_(self.name, event)

    def doubleclick(self, event):
        """Canvas button-press handler: notify colleagues of a double click."""
        if event.dblclick:
            self.inform('<Double-Button-1>')
            print('DoubleClick')

    # ------------------------------------------------------------------
    # Widget construction
    # ------------------------------------------------------------------
    def make_canvas(self):
        """Build the canvas frame, the optional toolbar, the figure and the optional sliders."""
        self.canvas_frame = tk.Frame(self, padx=5, pady=5, cursor='crosshair')
        self.canvas_frame.pack(side=tk.TOP, fill=tk.BOTH, expand=1)

        if self.toolbar:
            self.toolbar_frame = self.make_toolbar()
            self.toolbar_frame.pack(side=tk.TOP, fill=tk.X, expand=0)

        self.f = Figure(figsize=self.figsize, dpi=self.dpi)
        self.subplot = self.f.add_subplot(111)
        self.subplot.set_title(self.title, fontsize=10)
        plt.setp(self.subplot.get_xticklabels(), fontsize=4)
        plt.setp(self.subplot.get_yticklabels(), fontsize=4)
        self.canvas = FigureCanvasTkAgg(self.f, master=self.canvas_frame)
        # FigureCanvasTkAgg.show() no longer exists in current matplotlib;
        # draw() is the supported call and is available in old releases too.
        self.canvas.draw()
        self.canvas.get_tk_widget().pack(side=tk.TOP, fill=tk.BOTH, expand=1)
        self.canvas.mpl_connect('scroll_event', self.mouseWheel)
        self.canvas.mpl_connect('button_press_event', self.doubleclick)

        if self.wl:
            self.wl_scale = self.make_wl()
            self.wl_scale.pack(side=tk.TOP, fill=tk.X, expand=0)

    def make_wl(self):
        """Create and return the frame holding the level and window sliders."""
        f = font.Font(size=6)
        wl_frame = tk.Frame(self.canvas_frame)
        self.levelScale = tk.Scale(wl_frame, orient=tk.HORIZONTAL, from_=0.0, to=256.0, width=8, font=f, command=self.set_level)
        self.levelScale.pack(side=tk.TOP, fill=tk.X, expand=0)
        self.windowScale = tk.Scale(wl_frame, orient=tk.HORIZONTAL, from_=0.0, to=256.0, width=8, font=f, command=self.set_window)
        self.windowScale.pack(side=tk.TOP, fill=tk.X, expand=0)

        return wl_frame

    def make_info(self):
        """Create the 'Image Info' label frame inside the parent widget.

        NOTE(review): this rebinds ``self.info`` on the instance, hiding the
        ``info`` method defined below.  The toolbar is built earlier and
        already holds the bound method, so its button keeps working; the
        attribute name is kept because callers may rely on it.
        """
        self.info = tk.LabelFrame(self.parent, text='Image Info', padx=5, pady=5, width=400)
        self.info.pack(side=tk.RIGHT, fill=tk.BOTH, expand=0)

    def make_toolbar(self):
        """Create and return the toolbar frame with one flat icon button per action."""
        toolbar_frame = tk.Frame(self.canvas_frame)

        # (button text, icon path without extension, callback).  Real objects
        # are used instead of strings passed through eval().
        buttons = (
            ('DrawRectangle', 'icons/Rectangle', self.drawRectangle),
            ('Delete', 'icons/Delete', self.deleteRectangle),
            ('ZoomIn', 'icons/ZoomIn', self.zoomIn),
            ('ZoomOut', 'icons/ZoomOut', self.zoomOut),
            ('Reset', 'icons/ResetZoom', self.resetZoom),
            ('Move', 'icons/Move', self.move),
            ('Ruler', 'icons/Ruler', self.ruler),
            ('Histogram', 'icons/Histogram', self.histogram),
            ('Info', 'icons/Info', self.info),
            ('Save', 'icons/Save18', self.savePicture),
        )
        # The PhotoImage objects are stored on the instance so that they stay
        # referenced for as long as the buttons exist.
        self.imgToolbar = []
        for text, icon, command in buttons:
            self.imgToolbar.append(tk.PhotoImage(file=icon+'.gif'))
            button = tk.Button(toolbar_frame, image=self.imgToolbar[-1], text=text, relief=tk.FLAT, command=command)
            button.pack(side=tk.LEFT, fill=tk.BOTH, expand=0)
        return toolbar_frame

    # ------------------------------------------------------------------
    # Rectangle selection
    # ------------------------------------------------------------------
    def _disconnect_rectangle_handlers(self):
        """Disconnect the press/release/motion callbacks of a rectangle drawing, if any."""
        for cid in (self.bpe, self.bre, self.mne):
            if cid is not None:
                self.canvas.mpl_disconnect(cid)
        self.bpe = self.bre = self.mne = None

    def _remove_rectangle_patch(self):
        """Remove the current rectangle patch from the axes, if there is one."""
        if self.rectangle is not None:
            self.rectangle.remove()
            self.rectangle = None

    def drawRectangle(self):
        """Start an interactive rectangle selection on the canvas."""
        print('Draw rectangle!')
        if self.bpe is not None:
            # A previous selection was started but never finished: drop its
            # callbacks and placeholder patch so they do not pile up.
            self._disconnect_rectangle_handlers()
            self._remove_rectangle_patch()
        self.x0 = None
        self.y0 = None
        self.x1 = None
        self.y1 = None
        self.xp0 = None
        self.yp0 = None
        self.xp1 = None
        self.yp1 = None
        self.rectangle = Rectangle((0,0), 1, 1, facecolor='None', edgecolor='green')
        self.subplot.add_patch(self.rectangle)
        self.ispressed = False
        self.bpe = self.canvas.mpl_connect('button_press_event', self.drawRectangle_onPress)
        self.bre = self.canvas.mpl_connect('button_release_event', self.drawRectangle_onRelease)
        self.mne = self.canvas.mpl_connect('motion_notify_event', self.drawRectangle_onMotion)

    def drawRectangle_onPress(self, event):
        """Anchor the rectangle at the pressed point (ignored outside the axes)."""
        if event.xdata is None or event.ydata is None:
            # Press outside the axes: there are no data coordinates to anchor to.
            return
        self.xp0 = event.x
        self.yp0 = event.y
        self.x0 = event.xdata
        self.y0 = event.ydata
        self.x1 = event.xdata
        self.y1 = event.ydata
        self.rectangle.set_width(self.x1-self.x0)
        self.rectangle.set_height(self.y1-self.y0)
        self.rectangle.set_xy((self.x0, self.y0))
        self.rectangle.set_linestyle('dashed')
        self.canvas.draw()
        self.ispressed = True

    def drawRectangle_onRelease(self, event):
        """Finish the selection, notify colleagues and return the pixel corners.

        Returns ``(xp0, yp0, xp1, yp1)``, the press and release positions in
        canvas pixel coordinates.
        """
        if not self.ispressed:
            # Release without a preceding press inside the axes: keep waiting.
            return self.getRectanglePoints()
        self.xp1 = event.x
        self.yp1 = event.y
        if event.xdata is not None and event.ydata is not None:
            self.x1 = event.xdata
            self.y1 = event.ydata
        # Otherwise the button was released outside the axes: keep the last
        # data position seen while dragging.
        self.rectangle.set_width(self.x1-self.x0)
        self.rectangle.set_height(self.y1-self.y0)
        self.rectangle.set_xy((self.x0, self.y0))
        self.rectangle.set_linestyle('solid')
        self.canvas.draw()
        self.ispressed = False
        self._disconnect_rectangle_handlers()
        print(self.xp0, self.yp0, self.xp1, self.yp1)
        self.inform('<DrawRectangle>')
        return (self.xp0, self.yp0, self.xp1, self.yp1)

    def getRectanglePoints(self):
        """Return ``(xp0, yp0, xp1, yp1)`` of the last selection in pixel coordinates."""
        return (self.xp0, self.yp0, self.xp1, self.yp1)

    def drawRectangle_onMotion(self, event):
        """Stretch the dashed rectangle while the mouse button is held down."""
        if self.ispressed is not True:
            return
        if event.xdata is None or event.ydata is None:
            # Pointer left the axes while dragging; wait until it comes back.
            return
        self.x1 = event.xdata
        self.y1 = event.ydata
        self.rectangle.set_width(self.x1-self.x0)
        self.rectangle.set_height(self.y1-self.y0)
        self.rectangle.set_xy((self.x0, self.y0))
        self.rectangle.set_linestyle('dashed')
        self.canvas.draw()

    def deleteRectangle(self):
        """Remove the current rectangle (no-op when there is none) and notify colleagues."""
        print('Delete rectangle!')
        if self.rectangle is None:
            return
        # Also stop a selection that is still in progress, otherwise its
        # callbacks would act on the removed patch.
        self._disconnect_rectangle_handlers()
        self._remove_rectangle_patch()
        self.ispressed = False
        self.canvas.draw()
        self.inform('<DeleteRectangle>')

    # ------------------------------------------------------------------
    # Zoom
    # ------------------------------------------------------------------
    def zoomIn(self):
        """Crop one zoom step from every border of the displayed stack."""
        print('Zoom in!')
        if self.images is None:
            return
        shape = np.shape(self.images)
        print(shape)
        step = self._ZOOM_STEP
        if len(shape) < 3 or min(shape[1], shape[2]) <= 2 * step:
            # Cropping further would leave an empty image.
            return
        self.images = self.images[:, step:-step, step:-step]
        self.show_images()
        self.nZoom = self.nZoom+1

    def zoomOut(self):
        """Undo one zoom step, going back to the original stack at the last step."""
        print('ZoomOut!')
        if self.images is None or self.original_img is None:
            return
        if np.shape(self.images) == np.shape(self.original_img):
            return
        if self.nZoom > 1:
            margin = (self.nZoom-1) * self._ZOOM_STEP
            self.images = self.original_img[:, margin:-margin, margin:-margin]
            self.nZoom = self.nZoom-1
        else:
            self.images = self.original_img
            self.nZoom = 0
        # Redraw in both cases so the canvas also refreshes when the last
        # zoom step is undone.
        self.show_images()

    def resetZoom(self):
        """Show the original, uncropped stack again."""
        print('Reset zoom!')
        self.images = self.original_img
        self.show_images()
        self.nZoom = 0

    # ------------------------------------------------------------------
    # Toolbar actions
    # ------------------------------------------------------------------
    def histogram(self):
        """Open a window with a normalised 100-bin histogram of the displayed stack."""
        print('Histogram!')
        if self.images is None:
            return
        histo = tk.Toplevel()
        f = Figure(figsize=(4,4), dpi=100)
        subplot = f.add_subplot(111)
        subplot.set_title('Histogram', fontsize=10)
        plt.setp(subplot.get_xticklabels(), fontsize=8)
        plt.setp(subplot.get_yticklabels(), fontsize=8)
        canvas = FigureCanvasTkAgg(f, master=histo)
        canvas.draw()
        canvas.get_tk_widget().pack(side=tk.TOP, fill=tk.BOTH, expand=1)
        # 'normed' was removed from Axes.hist; 'density' is its replacement.
        subplot.hist(self.images.flatten(), 100, density=True, histtype='step', fc='k', ec='k')
        canvas.draw()
        close_button = tk.Button(histo, text='Close', command = histo.destroy)
        close_button.pack(side=tk.TOP)

    def savePicture(self):
        """Ask for a file name and save the figure as a 1200 dpi PNG."""
        print('Save!')
        savefile = tk.filedialog.asksaveasfilename(title='Save image as ...', defaultextension='png',
                                                  filetypes=[('all files', '.*'), ('png files', '.png')])
        if savefile:
            self.f.savefig(savefile, dpi=1200, format='png', )

    def mouseWheel(self, event):
        """Scroll handler: step to the next ('down') or previous ('up') image."""
        print('Test mouseWheel.')
        if self.images is None:
            return
        # NOTE(review): both event names lack the closing '>' used by the
        # other notifications; kept as-is because colleagues may match on
        # these exact strings.
        if event.button == 'down':
            self.indexOfImg = self.indexOfImg+1
            self.show_images()
            self.inform('<MouseWheelDown')
        if event.button == 'up':
            self.indexOfImg = self.indexOfImg-1
            self.show_images()
            self.inform('<MouseWheelUp')

    def change_wl(self, arg):
        """Placeholder handler; only prints a test message."""
        print('Button press event test')

    # ------------------------------------------------------------------
    # Image handling
    # ------------------------------------------------------------------
    def set_images(self, images):
        """Display ``images``; the first stack ever set is kept as the zoom reference."""
        self.images = images
        if self.img_is_set == False:
            self.original_img = images
        self.show_images()
        self.img_is_set = True

    def set_metadata(self, data):
        """Store ``data`` as the panel's metadata."""
        self.metadata = data

    def _frame_range(self):
        """Return ``(min, max)`` of the current image as Python floats."""
        frame = self.images[self.indexOfImg]
        # float() keeps the later sum/difference from wrapping around for
        # unsigned integer images.
        return float(np.min(frame)), float(np.max(frame))

    def show_images(self):
        """Draw the image at ``indexOfImg`` (reset to 0 when out of range)."""
        if self.images is None:
            return
        if not (0 <= self.indexOfImg < np.size(self.images, 0)):
            self.indexOfImg = 0

        previous = self.img
        self.img = self.subplot.imshow(self.images[self.indexOfImg])
        # Drop the image drawn by the previous call so artists do not pile up
        # on the axes.  The image the colorbar was created from is kept
        # because the colorbar follows that mappable.
        if previous is not None and (self.cbar is None or previous is not self.cbar.mappable):
            previous.remove()

        low, high = self._frame_range()
        if self.level_is_set == False:
            # Centre the level in the data range and open the window fully.
            self.level = (high + low) / 2
            self.window = high - low
            if self.wl:
                # The sliders only exist when the panel was created with wl=True.
                self.levelScale.config(from_=low)
                self.levelScale.config(to=high)
                self.levelScale.set(self.level)
                self.windowScale.config(from_=0)
                self.windowScale.config(to=self.window)
                self.windowScale.set(self.window)
                if self.window < 1:
                    self.levelScale.config(resolution=0.0001)
                    self.windowScale.config(resolution=0.0001)
            self.img.set_clim(self.level-self.window/2, self.level+self.window/2)

        if self.cb:
            if self.cbar is not None:
                # Colorbar.set_clim was removed from matplotlib; the limits
                # are set on the colorbar's mappable instead.
                self.cbar.mappable.set_clim(low, high)
            else:
                self.cbar = self.f.colorbar(self.img)
        self.canvas.draw()

    def set_level(self, event):
        """Level-slider callback: apply the new level and cap the window slider.

        ``self.window`` becomes the widest window that stays inside the data
        range while centred on the new level.
        """
        if self.images is None or self.img is None:
            return
        self.level = self.levelScale.get()
        low, high = self._frame_range()
        if self.level >= (high + low) / 2:
            self.window = 2*(high-self.level)
        else:
            self.window = 2*(self.level-low)
        print(self.level, self.window)
        # Use the window chosen on the slider unless it exceeds the cap.
        shown = min(self.windowScale.get(), self.window)
        self.img.set_clim(float(self.level-shown/2), float(self.level+shown/2))
        self.windowScale.config(to = self.window)
        self.canvas.draw()

    def set_window(self, event):
        """Window-slider callback: apply the new window around the current level."""
        if self.img is None or self.level is None:
            return
        self.window = self.windowScale.get()
        self.img.set_clim(float(self.level-self.window/2), float(self.level+self.window/2))
        self.canvas.draw()

    def set_indexOfImg(self, index):
        """Select the image index used by the next ``show_images`` call."""
        self.indexOfImg = index

    def ruler(self):
        """Placeholder for the ruler tool; only prints a message."""
        print('Measure!')

    def move(self):
        """Placeholder for the move tool; only prints a message."""
        print('Move!')

    def info(self):
        """Open a window that contains only a 'Close' button."""
        info = tk.Toplevel()
        tk.Button(info, text='Close', command = info.destroy).pack(side=tk.TOP)
Ejemplo n.º 2
0
class click_yrange:
   '''Interactive y-range selector.

   Starting from the height y0, a translucent rectangle spanning the
   current x-bounds of the axes is stretched vertically to follow the
   mouse.  Call connect() to start tracking and close() to stop; close()
   returns the selected (y0, y_end) pair.  Similar to click_window, but
   meant for picking a y-range.'''

   def __init__(self, ax, y0):
      self.ax = ax
      self.y0 = y0
      left, right = ax.get_xbound()
      band = Rectangle((left, y0), width=(right - left), height=0, alpha=0.1)
      self.rect = band
      ax.add_artist(band)

   def connect(self):
      '''Start following the mouse; the callback id is kept for close().'''
      canvas = self.rect.figure.canvas
      self.cidmotion = canvas.mpl_connect('motion_notify_event',
                                          self.on_motion)

   def on_motion(self, event):
      '''Stretch the band to the pointer while it is inside our axes.'''
      if event.inaxes == self.rect.axes:
         self.rect.set_height(event.ydata - self.y0)
         self.ax.figure.canvas.draw()

   def close(self):
      '''Stop tracking, erase the band and return (y0, y_end).'''
      self.rect.figure.canvas.mpl_disconnect(self.cidmotion)
      self.rect.remove()
      # The band is no longer attached to a figure, so redraw through
      # the axes instead.
      self.ax.figure.canvas.draw()
      y_end = self.rect.get_y() + self.rect.get_height()
      return (self.y0, y_end)
Ejemplo n.º 3
0
class click_window:
    '''Interactive rectangular selection.

    Starting from the corner (x0, y0), a translucent rectangle follows the
    mouse until close() is called.  close() returns the flattened corner
    coordinates of the final rectangle.  Useful for selecting out
    rectangular regions.'''

    def __init__(self, ax, x0, y0):
        self.ax, self.x0, self.y0 = ax, x0, y0
        corner = (x0, y0)
        self.rect = Rectangle(corner, width=0, height=0, alpha=0.1)
        ax.add_artist(self.rect)

    def connect(self):
        '''Start following the mouse; the callback id is kept for close().'''
        canvas = self.rect.figure.canvas
        self.cidmotion = canvas.mpl_connect('motion_notify_event',
                                            self.on_motion)

    def on_motion(self, event):
        '''Resize the rectangle so its far corner sits under the pointer.'''
        if event.inaxes == self.rect.axes:
            self.rect.set_width(event.xdata - self.x0)
            self.rect.set_height(event.ydata - self.y0)
            self.ax.figure.canvas.draw()

    def close(self):
        '''Stop tracking, erase the rectangle and return its corners as a flat list.'''
        self.rect.figure.canvas.mpl_disconnect(self.cidmotion)
        # Read the extent before the patch is detached from the axes.
        corners = self.rect.get_bbox().get_points()
        self.rect.remove()
        self.ax.figure.canvas.draw()
        flat = ravel(corners)
        return list(flat)
Ejemplo n.º 4
0
class click_yrange:
    '''Interactive selector for a y-range.

    A full-width translucent rectangle is anchored at y0 and its height
    follows the mouse.  Similar to click_window, but suited to selecting a
    range of y values; close() returns that range as (y0, y_end).'''

    def __init__(self, ax, y0):
        x_low, x_high = ax.get_xbound()
        span = x_high - x_low
        self.ax = ax
        self.y0 = y0
        self.rect = Rectangle((x_low, y0), width=span, height=0, alpha=0.1)
        ax.add_artist(self.rect)

    def connect(self):
        '''Subscribe on_motion to the canvas' mouse-motion events.'''
        self.cidmotion = self.rect.figure.canvas.mpl_connect(
            'motion_notify_event',
            self.on_motion,
        )

    def on_motion(self, event):
        '''Follow the pointer; events from outside our axes are ignored.'''
        outside = event.inaxes != self.rect.axes
        if outside:
            return

        new_height = event.ydata - self.y0
        self.rect.set_height(new_height)
        self.ax.figure.canvas.draw()

    def close(self):
        '''Unsubscribe, remove the rectangle, redraw and return (y0, y_end).'''
        self.rect.figure.canvas.mpl_disconnect(self.cidmotion)
        self.rect.remove()
        self.ax.figure.canvas.draw()
        selection = (self.y0, self.rect.get_y() + self.rect.get_height())
        return selection
Ejemplo n.º 5
0
class click_window:
   '''Interactive window selector.

   Given an axes instance and a start point (x0, y0), a translucent
   rectangle is drawn whose opposite corner follows the mouse until
   close() is called; close() returns the coordinates of the final
   rectangle.  Useful for selecting out rectangular regions.'''

   def __init__(self, ax, x0, y0):
      patch = Rectangle((x0, y0), width=0, height=0, alpha=0.1)
      self.ax = ax
      self.x0 = x0
      self.y0 = y0
      self.rect = patch
      ax.add_artist(patch)

   def connect(self):
      '''Subscribe on_motion to the canvas' mouse-motion events.'''
      self.cidmotion = self.rect.figure.canvas.mpl_connect(
         'motion_notify_event',
         self.on_motion,
      )

   def on_motion(self, event):
      '''Track the pointer; events from outside our axes are ignored.'''
      left_axes = event.inaxes != self.rect.axes
      if left_axes:
         return

      dx = event.xdata - self.x0
      self.rect.set_width(dx)
      dy = event.ydata - self.y0
      self.rect.set_height(dy)
      self.ax.figure.canvas.draw()

   def close(self):
      '''Unsubscribe, remove the rectangle and return its corners as a flat list.'''
      self.rect.figure.canvas.mpl_disconnect(self.cidmotion)
      points = self.rect.get_bbox().get_points()
      self.rect.remove()
      self.ax.figure.canvas.draw()
      return list(ravel(points))
Ejemplo n.º 6
0
class Layer(object):
    """A horizontal layer with index ``n`` drawn as a filled rectangle between ``yf`` and ``y0``."""

    def __init__(self,n,y0,yf,nprev = 1,nnext = 1,dndlambda=1e-3):
        """Store the layer parameters and build its (not yet attached) rectangle.

        n, nprev, nnext -- index of this layer and of its two neighbours
        y0, yf          -- the rectangle spans from yf up by (y0 - yf)
        dndlambda       -- change of the index per nm of wavelength
        """
        # The fill colour is derived from the index: the first two
        # channels share the same value.
        shade = 1./(n**2)
        self.color = (shade,shade,1./np.sqrt(n))
        self.n, self.nprev, self.nnext = n, nprev, nnext
        self.y0, self.yf = y0, yf
        #per nm
        self.dndlambda = dndlambda
        self._artist = Rectangle((-1,yf),2,y0-yf,fill=True,facecolor=self.color)

    def set_master(self,master):
        """Remember ``master`` and add the layer's rectangle to its axes."""
        self.master = master
        master.axes.add_artist(self._artist)

    def contains(self,y):
        """Return True when ``yf <= y < y0``."""
        return self.yf <= y < self.y0

    def ns_for_lambda(self,lambda_):
        """Return ``(n, nprev, nnext)`` adjusted for the wavelength ``lambda_``.

        An index equal to 1 is returned unchanged.  Any other index is
        shifted linearly, measured from LAMBDA0 for a positive slope and
        from LAMBDAf otherwise, and never drops below 1.
        """
        slope = self.dndlambda

        def dispersed(index):
            if index == 1:
                return 1
            if slope > 0:
                return max(1,index+(lambda_-LAMBDA0)*slope)
            return max(1,index+(LAMBDAf-lambda_)*-slope)

        return dispersed(self.n),dispersed(self.nprev),dispersed(self.nnext)

    def remove(self):
        """Remove the layer's rectangle from wherever it was added."""
        self._artist.remove()
    def _on_figure_motion(self, event):
        if event.inaxes == None:
            return

        self._motion_wait -= 1
        if self._motion_wait > 0:
            return

        x0, y0, x1, y1 = event.inaxes.dataLim.bounds

        number_of_points = len(event.inaxes.lines[-1].get_ydata())
        index = int(round((number_of_points-1) * (event.xdata-x0)/x1))

        if len(self._data[0]) < index + 1:
            return

        if self._background is None:
            self._background = self._figure.canvas.copy_from_bbox(self._graphs[0].bbox)

        # restore the clean slate background
        self._figure.canvas.restore_region(self._background)

        polygon = None
        if self._select_start == None:
            linev = self._graphs[-1].axvline(x=event.xdata, linewidth=1, color="#000000", alpha=0.5)
            lineh = self._graphs[7].axhline(y=event.ydata, linewidth=1, color="#000000", alpha=0.5)
            self._graphs[-1].draw_artist(linev)
            self._graphs[-1].draw_artist(lineh)
        else:
            width = abs(event.xdata - self._select_start)
            start = self._select_start
            if (event.xdata < start):
                start = event.xdata
            if width < 20:
                col = "#aa4444"
            else:
                col = "#888888"
            polygon = Rectangle((start, 0), width, y1 + 10000, facecolor=col, alpha=0.5)
            self._graphs[-1].add_patch(polygon)
            self._graphs[-1].draw_artist(polygon)

        self._figure.canvas.blit(self._graphs[-1].bbox)
        if self._select_start == None:
            linev.remove()
            lineh.remove()
        if polygon != None:
            polygon.remove()

        for i in xrange(0, 8):
            if (i < 2):
                val = str(self._data[i][index])
            else:
                val = str(int(self._data[i][index]))                    
            self._mouse_texts[i].setText(val)
Ejemplo n.º 8
0
class RectEditor(object):
    """Lets the user redefine a rectangle on an axis with a press/release mouse gesture."""

    def __init__(self, fig, axis, rectangle):
        """Keep the figure, axis and initial rectangle; no callbacks are connected yet."""
        self.fig, self.axis, self.rectangle = fig, axis, rectangle
        # Press (x0, y0) and release (x1, y1) positions in data coordinates.
        self.x0 = self.y0 = 0
        self.x1 = self.y1 = 0
        # Set to 1 by the first press inside the axis.
        self.done = 0
        # Callback ids, filled in by connect().
        self.cidpress = self.cidrelease = self.cidclose = None

    def connect(self):
        """Subscribe to press, release and close events of the figure canvas."""
        self.cidpress = self.fig.canvas.mpl_connect('button_press_event', self.on_press)
        self.cidrelease = self.fig.canvas.mpl_connect('button_release_event', self.on_release)
        self.cidclose = self.fig.canvas.mpl_connect('close_event', self.window_closed)

    def on_press(self, event):
        """Record the press position when it happens inside our axis."""
        if event.inaxes != self.axis:
            return
        self.done = 1
        self.x0 = event.xdata
        self.y0 = event.ydata

    def on_release(self, event):
        """Replace the rectangle by one spanning the press and release points."""
        if event.inaxes != self.axis:
            return
        self.x1 = event.xdata
        self.y1 = event.ydata
        self.rectangle.remove()
        lower_left = (min(self.x0, self.x1), min(self.y0, self.y1))
        self.rectangle = Rectangle(lower_left,
                                   abs(self.x1-self.x0), abs(self.y1-self.y0),
                                   color='g', linewidth=3, fill=False)
        self.axis.add_artist(self.rectangle)
        self.fig.canvas.draw()

    def window_closed(self, event):
        """Disconnect all callbacks and drop the figure once its window is closed."""
        if not (event.canvas.figure == self.fig):
            return
        for cid in (self.cidpress, self.cidrelease, self.cidclose):
            self.fig.canvas.mpl_disconnect(cid)
        self.fig = None

    def remove_rectangle(self):
        """Remove the current rectangle from the axis and redraw."""
        self.rectangle.remove()
        self.fig.canvas.draw()

    def create_rectangle(self, rec):
        """Add a green outline copy of ``rec`` (same position and size) and redraw."""
        outline = Rectangle(rec.get_xy(), rec.get_width(), rec.get_height(), color='g', linewidth=3, fill=False)
        self.rectangle = outline
        self.axis.add_artist(outline)
        self.fig.canvas.draw()
Ejemplo n.º 9
0
    def plot_window(self, component, starttime, endtime, window_weight):
        """Draw a pickable window rectangle with its weight on one component axis.

        ``component`` selects the axis: "Z", "N" or "E"; anything else raises
        NotImplementedError.  The rectangle spans the full current y-range of
        that axis and runs from ``starttime`` to ``endtime``, positioned
        relative to the start time of the first synthetic trace.  The weight
        is written into the upper left corner of the rectangle.
        """
        for code, attribute in (("Z", "plot_axis_z"),
                                ("N", "plot_axis_n"),
                                ("E", "plot_axis_e")):
            if component == code:
                axis = getattr(self, attribute)
                break
        else:
            raise NotImplementedError

        trace = self.data["synthetics"][0]

        bottom, top = axis.get_ylim()
        left = starttime - trace.stats.starttime
        span = endtime - starttime
        rise = top - bottom

        rect = Rectangle((left, bottom), span, rise, facecolor="0.6",
                         alpha=0.5, edgecolor="0.5", picker=True)
        axis.add_patch(rect)

        # Label offset by 2% of the window size from the upper left corner.
        attached_text = axis.text(
            x=left + 0.02 * span, y=top - 0.02 * rise,
            s=str(window_weight), verticalalignment="top",
            horizontalalignment="left", color="0.4", weight=1000)

        # Monkey patch the instance so that removing the rectangle also
        # removes its label.
        def remove():
            super(Rectangle, rect).remove()
            attached_text.remove()
        rect.remove = remove
Ejemplo n.º 10
0
def UpdateZoomGizmo(scale, xrange, yrange):
    """Create, update or remove the zoom gizmo (container outline plus inner box).

    While the global ``zoom_factor`` is greater than 1 the gizmo is created
    on the global ``axis`` (first call) or repositioned and resized
    (later calls).  Otherwise an existing gizmo is removed and the globals
    ``z_container`` and ``z_box`` are reset to None.
    """
    global axis, zoom_factor, img_offset, z_container, z_box

    ratio = yrange / xrange

    if not zoom_factor > 1:
        # Not zoomed in: tear the gizmo down if it is currently shown.
        if z_container is None:
            return
        z_container.remove()
        z_container = None

        z_box.remove()
        z_box = None
        return

    # Change the size of the Gizmo
    side = 320
    width = side * scale
    height = width * ratio
    origin = (img_offset[0] - xrange * scale,
              img_offset[1] + yrange * scale - height)

    if z_container is not None:
        z_container.set_xy(origin)
        z_container.set_width(width)
        z_container.set_height(height)

        box_x = origin[0] + 0.5 * (img_offset[0] * width / xrange - width * scale)
        box_y = origin[1] + 0.5 * (img_offset[1] * height / yrange - height * scale)
        z_box.set_x(box_x)
        z_box.set_y(box_y)
        z_box.set_width(width * scale)
        z_box.set_height(height * scale)
        return

    # First zoomed call: build both patches, then add them to the axis.
    z_container = Rectangle(origin, width, height,
                            edgecolor="w", facecolor='none')
    z_container.label = "zoom_container"

    z_box = Rectangle(origin, width, height, alpha=0.5)
    z_box.label = "zoom_box"

    axis.add_artist(z_container)
    axis.add_artist(z_box)
Ejemplo n.º 11
0
def onclick(event, ax, text_count, imgno, imgsetx, gw, gh, nw, nh):
    """Record an ``nw`` x ``nh`` selection centred on the clicked point.

    event      -- mouse event; ``xdata``/``ydata`` give the clicked point
    ax         -- axes on which the selection is briefly drawn
    text_count -- text artist that displays the running count
    imgno      -- number shown as 'len' in the count text
    imgsetx    -- list that receives the ``[x, y]`` corner of each selection
    gw, gh     -- width and height of the full image
    nw, nh     -- width and height of the selection window

    The corner is clamped at 0.  A selection that would extend beyond
    ``gw``/``gh`` is ignored.  Accepted selections increment the global
    ``cimgs`` counter.  Clicks outside the axes are ignored.
    """
    global cimgs
    if event.xdata is None or event.ydata is None:
        # Click outside the axes: there is no data position to centre on.
        return
    # Centre the window on the click: half the width along x and half the
    # height along y, matching the Rectangle and the bounds check below.
    xst = max(int(event.xdata - (nw / 2)), 0)
    yst = max(int(event.ydata - (nh / 2)), 0)
    if xst <= (gw - nw) and yst <= (gh - nh):
        #display rectangle on image to show selection for downsampled image
        rect = Rectangle((xst, yst), nw, nh, alpha=0.3)
        ax.add_patch(rect)
        ax.figure.canvas.draw()
        imgpoint = [xst, yst]
        imgsetx.append(imgpoint)
        cimgs = cimgs + 1
        # The patch is detached again right after the draw above.
        rect.remove()
        text_count.set_text("count= %d  len= %d " % (cimgs, imgno))
Ejemplo n.º 12
0
class ItemArtist:
    """Draws one item of a layout state.

    An item consists of a blue top line, a blue bottom line, a green edge
    line and a centred text label showing its position.  In addition a grey
    box is shown when the item is the state's source item or one of its two
    expansion items.
    """

    def __init__(self, position, state):
        """Create all artists for the item at ``position`` of ``state``."""
        self.position = position
        slot = state.positions.index(position)

        # The state's coordinates are negated for plotting.
        self.top = -state.tops[slot]
        self.top_line, = pylab.plot([0, width], 2 * [self.top], c='b')

        self.bottom = -state.bottoms[slot]
        self.bottom_line, = pylab.plot([0, width], 2 * [self.bottom], c='b')

        self.edge = -state.edges[slot]
        self.edge_line, = pylab.plot([0, width], 2 * [self.edge], c='g')

        middle = (self.top + self.bottom) / 2
        self.label = Text(width / 2, middle, str(position),
                          va='center', ha='center')

        self.axes = pylab.gca()
        self.axes.add_artist(self.label)

        self.src_box = None
        self.exp_box = None
        self._check_boxes(state)

    @staticmethod
    def _grey_box(y, height):
        """Build a filled grey rectangle of the global width starting at ``y``."""
        return Rectangle((0, y), width, height,
                         fill=True, ec=None, fc='0.7')

    def _place_exp_box(self, before_src, gap):
        """Create or move the expansion box of height ``gap``."""
        gap_bottom = self.top - gap if before_src else self.bottom
        if self.exp_box is None:
            self.exp_box = self._grey_box(gap_bottom, gap)
            self.axes.add_patch(self.exp_box)
            return
        self.exp_box.set_y(gap_bottom)
        self.exp_box.set_height(gap)

    def _check_boxes(self, state):
        """Create, move or remove the grey boxes to match the item's role in ``state``."""
        if self.position == state.src:
            if self.src_box is None:
                self.src_box = self._grey_box(self.bottom,
                                              self.top - self.bottom)
                self.axes.add_patch(self.src_box)
            else:
                self.src_box.set_y(self.bottom)
                self.src_box.set_height(self.top - self.bottom)
        elif self.position == state.exp1:
            self._place_exp_box(state.exp1 < state.src, state.exp1_gap)
        elif self.position == state.exp2:
            self._place_exp_box(state.exp2 < state.src, state.exp2_gap)
        else:
            # Neither source nor expansion item: no box should be shown.
            if self.src_box is not None:
                self.src_box.remove()
                self.src_box = None
            if self.exp_box is not None:
                self.exp_box.remove()
                self.exp_box = None

    def inState(self, state):
        """Return whether this item's position occurs in ``state.positions``."""
        return self.position in state.positions

    def update(self, position, state):
        """Move the item to ``position`` and refresh its artists from ``state``."""
        if self.position != position:
            self.position = position
            self.label.set_text(str(position))

        slot = state.positions.index(self.position)

        previous_top, self.top = self.top, -state.tops[slot]
        top_moved = previous_top != self.top
        if top_moved:
            self.top_line.set_ydata(2 * [self.top])

        previous_bottom, self.bottom = self.bottom, -state.bottoms[slot]
        bottom_moved = previous_bottom != self.bottom
        if bottom_moved:
            self.bottom_line.set_ydata(2 * [self.bottom])

        previous_edge, self.edge = self.edge, -state.edges[slot]
        if previous_edge != self.edge:
            self.edge_line.set_ydata(2 * [self.edge])

        # Only a change of the top or bottom line requires the label and
        # the boxes to be adjusted; the edge line does not affect them.
        if top_moved or bottom_moved:
            self.label.set_y((self.top + self.bottom) / 2)
            self._check_boxes(state)

    def remove(self):
        """Remove every artist of this item from the plot."""
        for artist in (self.edge_line, self.top_line,
                       self.bottom_line, self.label):
            artist.remove()

        for box in (self.src_box, self.exp_box):
            if box is not None:
                box.remove()
Ejemplo n.º 13
0
class EarthPlot(FigureCanvas):
    """This class is the core of the application. It displays the Earth
    map, the patterns, stations, elevation contours, etc.
    """

    # EarthPlot constructor
    def __init__(self,
                 parent=None,
                 width=5,
                 height=5,
                 dpi=300,
                 proj='nsper',
                 res='crude',
                 config=None):
        """Build the canvas, apply the optional configuration and draw the map.

        Args:
            parent: central widget hosting the canvas; ``parent.parent()``
                is used as the application object.
            width: figure width passed to ``Figure(figsize=...)``.
            height: figure height passed to ``Figure(figsize=...)``.
            dpi: figure resolution passed to ``Figure``.
            proj: initial projection name (``'nsper'`` and ``'cyl'`` are the
                values this class tests for).
            res: initial map resolution name.
            config: optional configuration object queried through
                ``get``/``getint``/``getfloat``/``getboolean`` with
                ``fallback=`` (presumably a ``configparser.ConfigParser``;
                confirm against the caller).  Numbered ``PATTERNn``,
                ``STATIONn`` and ``ELEVATIONn`` sections are loaded in
                sequence starting at 1 until a number is missing.
        """
        utils.trace('in')

        # Store Canvas properties
        self._plot_title = 'Default Title'
        self._width = width
        self._height = height
        self._dpi = dpi
        self._centralwidget = parent
        self._app = self._centralwidget.parent()
        # store _earth_map properties
        self._projection = proj
        self._resolution = res

        # define figure in canvas
        self._figure = Figure(figsize=(self._width, self._height),
                              dpi=self._dpi)
        self._axes = self._figure.add_subplot(111)

        FigureCanvas.__init__(self, self._figure)
        self.setParent(self._centralwidget)
        self._app = self.parent().parent()
        FigureCanvas.setSizePolicy(self, QSizePolicy.Expanding,
                                   QSizePolicy.Expanding)
        FigureCanvas.updateGeometry(self)
        self.setFocusPolicy(QtCore.Qt.ClickFocus)
        self.setFocus()

        # initialize EarthPlot fields
        self._patterns = {}
        self._elev = {}
        self._stations = []
        self._polygons = []
        self._clrbar = None
        self._clrbar_axes = None
        self._earth_map = None
        self._coastlines_col = None
        self._countries_col = None
        self._parallels_col = None
        self._meridians_col = None

        # initialize PPlot limits
        self.llcrnrx = None
        self.llcrnry = None
        self.urcrnrx = None
        self.urcrnry = None
        self.llcrnrlon = None
        self.llcrnrlat = None
        self.urcrnrlon = None
        self.urcrnrlat = None
        self.centerx = None
        self.centery = None
        self.cntrlon = None
        self.cntrlat = None

        # if a config has been provided by caller
        # NOTE(review): the viewer, zoom and map style attributes are only
        # assigned inside this branch, while draw_elements() at the end of
        # the constructor runs unconditionally -- confirm that callers
        # always pass a config.
        if config:
            # get font size with fallback = 5
            fontsize = config.getint('DEFAULT', 'font size', fallback=5)
            # set default font size
            plt.rcParams.update({'font.size': fontsize})
            self._axes.xaxis.label.set_fontsize(fontsize)
            self._axes.yaxis.label.set_fontsize(fontsize)

            # set map resolution (take only first letter in lower case)
            self._resolution = config.get('DEFAULT',
                                          'map resolution',
                                          fallback=self._resolution).lower()
            self._app.getmenuitem(item='View>Map resolution>{res}'.format(
                res=self._resolution)).setChecked(True)
            self._resolution = self._resolution[0]
            self._projection = config.get('DEFAULT',
                                          'projection',
                                          fallback='nsper')
            if self._projection == 'nsper':
                self._app.getmenuitem(
                    item='View>Projection>Geo').setChecked(True)
            elif self._projection == 'cyl':
                self._app.getmenuitem(
                    item='View>Projection>Cylindrical').setChecked(True)

            # get point of view coordinates if defined
            longitude = config.getfloat('VIEWER', 'longitude', fallback=0.0)
            latitude = config.getfloat('VIEWER', 'latitude', fallback=0.0)
            altitude = config.getfloat('VIEWER',
                                       'altitude',
                                       fallback=cst.ALTGEO)

            # get Earth plot configuration
            self._bluemarble = config.getboolean('DEFAULT',
                                                 'blue marble',
                                                 fallback=False)
            self._app.getmenuitem(item='View>Blue Marble').setChecked(
                self._bluemarble)
            self._coastlines = config.get('DEFAULT',
                                          'coast lines',
                                          fallback='light')
            self._app.getmenuitem(item='View>Coast lines>' +
                                  self._coastlines).setChecked(True)
            self._countries = config.get('DEFAULT',
                                         'countries',
                                         fallback='light')
            self._app.getmenuitem(item='View>Country borders>{cntry}'.format(
                cntry=self._countries)).setChecked(True)
            self._parallels = config.get('DEFAULT',
                                         'parallels',
                                         fallback='light')
            self._app.getmenuitem(item='View>Parallels>' +
                                  self._parallels).setChecked(True)
            self._meridians = config.get('DEFAULT',
                                         'meridians',
                                         fallback='light')
            self._app.getmenuitem(item='View>Meridians>' +
                                  self._meridians).setChecked(True)

            # initialize angle of view
            # Satellite Longitude, latitude and altitude
            self._viewer = Viewer(lon=longitude, lat=latitude, alt=altitude)

            # get default directory
            self.rootdir = config.get('DEFAULT', 'root', fallback='C:\\')

            # Initialize zoom
            self._zoom = Zoom(self._projection)
            self._zoom.min_azimuth = config.getfloat('GEO',
                                                     'min azimuth',
                                                     fallback=-9)
            self._zoom.min_elevation = config.getfloat('GEO',
                                                       'min elevation',
                                                       fallback=-9)
            self._zoom.max_azimuth = config.getfloat('GEO',
                                                     'max azimuth',
                                                     fallback=9)
            self._zoom.max_elevation = config.getfloat('GEO',
                                                       'max elevation',
                                                       fallback=9)
            self._zoom.min_longitude = config.getfloat('CYLINDRICAL',
                                                       'min longitude',
                                                       fallback=-180)
            self._zoom.min_latitude = config.getfloat('CYLINDRICAL',
                                                      'min latitude',
                                                      fallback=-90)
            self._zoom.max_longitude = config.getfloat('CYLINDRICAL',
                                                       'max longitude',
                                                       fallback=180)
            self._zoom.max_latitude = config.getfloat('CYLINDRICAL',
                                                      'max latitude',
                                                      fallback=90)

            # add patterns from ini file
            pattern_index = 1
            pattern_section = 'PATTERN' + str(pattern_index)
            while pattern_section in config:
                if 'file' in config[pattern_section]:
                    conf = {}
                    conf['filename'] = config.get(pattern_section, 'file')
                    conf['sat_lon'] = config.getfloat(pattern_section,
                                                      'longitude',
                                                      fallback=0.0)
                    conf['sat_lat'] = config.getfloat(pattern_section,
                                                      'latitude',
                                                      fallback=0.0)
                    conf['sat_alt'] = config.getfloat(pattern_section,
                                                      'altitude',
                                                      fallback=cst.ALTGEO)
                    conf['sat_yaw'] = config.getfloat(pattern_section,
                                                      'yaw',
                                                      fallback=0.0)
                    conf['title'] = config.get(pattern_section,
                                               'title',
                                               fallback='Default title')
                    conf['level'] = config.get(pattern_section,
                                               'level',
                                               fallback='25, 30, 35, 38, 40')
                    conf['revert_x'] = config.getboolean(pattern_section,
                                                         'revert x-axis',
                                                         fallback=False)
                    conf['revert_y'] = config.getboolean(pattern_section,
                                                         'revert y-axis',
                                                         fallback=False)
                    conf['rotate'] = config.getboolean(pattern_section,
                                                       'rotate',
                                                       fallback=False)
                    conf['use_second_pol'] = \
                        config.getboolean(pattern_section,
                                          'second polarisation',
                                          fallback=False)
                    conf['display_slope'] = \
                        config.getboolean(pattern_section,
                                          'slope',
                                          fallback=False)
                    conf['shrink'] = config.getboolean(pattern_section,
                                                       'shrink',
                                                       fallback=False)
                    conf['azshrink'] = config.getfloat(pattern_section,
                                                       'azimuth shrink',
                                                       fallback=0.0)
                    conf['elshrink'] = config.getfloat(pattern_section,
                                                       'elevation shrink',
                                                       fallback=0.0)
                    conf['offset'] = config.getboolean(pattern_section,
                                                       'offset',
                                                       fallback=False)
                    conf['azeloffset'] = config.getboolean(pattern_section,
                                                           'azeloffset',
                                                           fallback=True)
                    conf['azoffset'] = config.getfloat(pattern_section,
                                                       'azimuth offset',
                                                       fallback=0.0)
                    conf['eloffset'] = config.getfloat(pattern_section,
                                                       'elevation offset',
                                                       fallback=0.0)
                    conf['cf'] = config.getfloat(pattern_section,
                                                 'conversion factor',
                                                 fallback=0.0)
                    conf['linestyles'] = config.get(pattern_section,
                                                    'linestyles',
                                                    fallback='solid')
                    conf['linewidths'] = \
                        cst.BOLDNESS[config.get(pattern_section,
                                                'linewidths',
                                                fallback='medium')]
                    # iso-levels are given as a comma separated list
                    conf['isolevel'] = [
                        float(s) for s in conf['level'].split(',')
                    ]
                    conf['Color surface'] = config.getboolean(pattern_section,
                                                              'Color surface',
                                                              fallback=False)

                    self.loadpattern(conf=conf)

                    self.settitle(conf['title'])

                # check for next pattern; this is done outside the ``if`` so
                # that a section without a 'file' entry is skipped instead
                # of being tested again forever
                pattern_index += 1
                pattern_section = 'PATTERN' + str(pattern_index)
            # add stations from ini file
            station_index = 1
            station_section = 'STATION' + str(station_index)
            while station_section in config:
                # load stations from sta file
                if 'file' in config[station_section]:
                    station_file = config.get(station_section, 'file')
                    self._stations.extend(
                        stn.get_station_from_file(station_file, self))
                # load station from description in ini file
                elif 'name' in config[station_section]:
                    station = stn.Station(parent=self)
                    station.configure(config._sections[station_section])
                    stncontroller = stn.StationControler(parent=self,
                                                         station=station)
                    self._stations.append(stncontroller)
                # check for next station section
                station_index += 1
                station_section = 'STATION' + str(station_index)
            # add elevation contour from ini file
            elevation_index = 1
            elevation_section = 'ELEVATION' + str(elevation_index)
            while elevation_section in config:
                # the 'elevation' entry is a comma separated list of values,
                # one contour is created per value
                section_conf = config._sections[elevation_section]
                elevationlist = [
                    float(s) for s in section_conf['elevation'].split(',')
                ]
                for elevation_value in elevationlist:
                    # work on a copy so that each contour gets its own
                    # settings and the parsed configuration keeps its
                    # original string value
                    conf = dict(section_conf)
                    conf['elevation'] = elevation_value
                    elevation = elv.Elevation(parent=self)
                    elevation.configure(conf)
                    self._elev['Elev[' + str(elevation_value) +
                               ']'] = elevation
                # check for next elevation section
                elevation_index += 1
                elevation_section = 'ELEVATION' + str(elevation_index)

        # initialise reference to Blue Marble
        self._bluemarble_imshow = None

        # default file name to save figure
        self.filename = 'plot.PNG'

        # connect canvas to mouse event (enable zoom and recenter)
        self.zoomposorigin = None
        self.zoomposfinal = None
        self.zoompatch = None
        self.dragorigin = None
        # detect motion to update zoom rectangle
        self.mpl_connect('motion_notify_event', self.mouse_move)
        # detect mouse press to recenter or initiate drag and zoom
        self.mpl_connect('button_press_event', self.mouse_press)
        # detect mouse button release to finalize zoom
        self.mpl_connect('button_release_event', self.mouse_release)
        # detect keyboard key press for shortcut
        self.mpl_connect('key_press_event', self.key_press)

        # draw the already loaded elements
        self.draw_elements()

        utils.trace('out')

    # End of EarthPlot constructor

    def mouse_move(self, event):
        """Set mouse longitude and latitude plus directivity in the status bar.

        Also stretches the zoom rectangle while a zoom is in progress and
        pans the viewer while a drag is in progress.

        The gain is ``None`` when no pattern is selected in the application
        combo box, when the selected name is not a loaded pattern, or when
        the directivity computation raises ``TypeError``.  Positions outside
        the current azimuth/elevation zoom window are reported as NaN.
        """
        # get coordinates of the mouse motion event
        xevent = event.x
        yevent = event.y
        bbox = event.canvas.figure.axes[0].bbox
        # compute longitude and latitude from the bbox of the event
        mouselon, mouselat = self.get_mouse_ll(xevent, yevent, bbox)
        if mouselon > cst.MAX_LON or mouselon < cst.MIN_LON:
            mouselon = np.nan
        if mouselat > cst.MAX_LAT or mouselat < cst.MIN_LAT:
            mouselat = np.nan
        # compute mouse azimuth and elevation for directivity computation
        mouseaz, mouseel = self.get_mouse_azel(xevent, yevent, bbox)

        # gain of the pattern currently selected in the combo box, if any
        # (value comparison and membership test: identity tests against
        # ``{}`` and ``''`` literals do not check for emptiness)
        gain = None
        pattern_key = self._app.getpatterncombo()
        if pattern_key != '' and pattern_key in self._patterns:
            pattern = self._patterns[pattern_key].get_pattern()
            try:
                gain = pattern.directivity(mouselon, mouselat) + \
                    pattern.configure()['cf']
            except TypeError:
                gain = None

        if mouseaz >= self._zoom.max_azimuth or \
                mouseaz <= self._zoom.min_azimuth or \
                mouseel >= self._zoom.max_elevation or \
                mouseel <= self._zoom.min_elevation:
            mouselon = np.nan
            mouselat = np.nan
            gain = np.nan
        # set status bar text
        app = self.parent().parent()
        app.setmousepos(mouselon, mouselat, gain)

        # if start of zoom is defined set final position
        if self.zoomposorigin is not None:
            mousex, mousey = self.get_mouse_xy(xevent, yevent, bbox)
            self.zoomposfinal = mouseaz, mouseel, \
                mouselon, mouselat, \
                mousex, mousey
            # update and draw patch
            self.zoompatch.set_x(min(mousex, self.zoomposorigin[4]))
            self.zoompatch.set_y(min(mousey, self.zoomposorigin[5]))
            self.zoompatch.set_width(abs(mousex - self.zoomposorigin[4]))
            self.zoompatch.set_height(abs(mousey - self.zoomposorigin[5]))
            self.draw()

        # a NaN position (pointer outside the map or the zoom window) must
        # not reach the viewer, so it is skipped and the last valid drag
        # origin is kept
        if self.dragorigin is not None and \
                not (np.isnan(mouselon) or np.isnan(mouselat)):
            originlon, originlat = self.dragorigin
            if not (np.isnan(originlon) or np.isnan(originlat)):
                deltalon = mouselon - originlon
                deltalat = mouselat - originlat
                self._viewer.longitude(self._viewer.longitude() - deltalon)
                self._viewer.latitude(self._viewer.latitude() - deltalat)
                self.draw_elements()
                app = self.parent().parent()
                app.setviewerpos(self._viewer.longitude(),
                                 self._viewer.latitude(),
                                 self._viewer.altitude())
            self.dragorigin = mouselon, mouselat

    # end of method mouse_move

    def get_mouse_xy(self, xmouse, ymouse, bbox):
        """Convert a mouse position to x and y in basemap coordinates.

        The position is first expressed as a fraction of the box described
        by ``bbox.bounds`` (origin x, origin y, width, height), then scaled
        to the extent of ``self._earth_map`` between its lower left and
        upper right corners.
        """
        left, bottom, pixel_width, pixel_height = bbox.bounds[:4]
        # relative position of the mouse inside the box
        rel_x = (xmouse - left) / pixel_width
        rel_y = (ymouse - bottom) / pixel_height
        # scale to the dimensions of the basemap plot
        earth = self._earth_map
        map_x = earth.llcrnrx + rel_x * (earth.urcrnrx - earth.llcrnrx)
        map_y = earth.llcrnry + rel_y * (earth.urcrnry - earth.llcrnry)
        return map_x, map_y

    # end of function get_mouse_xy

    def get_mouse_ll(self, xmouse, ymouse, bbox):
        """Convert a mouse position to longitude and latitude.

        The position is converted to basemap coordinates with
        ``get_mouse_xy`` and passed to ``self._earth_map`` with
        ``inverse=True``.  A longitude or latitude outside the
        ``cst.MIN_*``/``cst.MAX_*`` bounds is returned as NaN.
        """
        map_x, map_y = self.get_mouse_xy(xmouse, ymouse, bbox)
        lon, lat = self._earth_map(map_x, map_y, inverse=True)
        # eliminate out of the Earth cases
        lon_off_earth = lon > cst.MAX_LON or lon < cst.MIN_LON
        lat_off_earth = lat > cst.MAX_LAT or lat < cst.MIN_LAT
        return (np.nan if lon_off_earth else lon,
                np.nan if lat_off_earth else lat)

    # end of function get_mouse_ll

    def get_mouse_azel(self, xmouse, ymouse, bbox):
        """Convert a mouse position to azimuth and elevation.

        The position is expressed as a fraction of the box described by
        ``bbox.bounds`` and mapped linearly onto the azimuth/elevation
        window stored in ``self._zoom``.  Both results are clamped to that
        window.
        """
        zoom = self._zoom
        left, bottom, pixel_width, pixel_height = bbox.bounds[:4]
        # relative position of the mouse inside the box
        rel_x = (xmouse - left) / pixel_width
        rel_y = (ymouse - bottom) / pixel_height
        # linear mapping onto the zoom window
        azimuth = rel_x * (zoom.max_azimuth - zoom.min_azimuth) \
            + zoom.min_azimuth
        elevation = rel_y * (zoom.max_elevation - zoom.min_elevation) \
            + zoom.min_elevation
        # clamp to the window
        azimuth = max(min(azimuth, zoom.max_azimuth), zoom.min_azimuth)
        elevation = max(min(elevation, zoom.max_elevation),
                        zoom.min_elevation)
        return azimuth, elevation

    # end of function get_mouse_azel

    def mouse_press(self, event):
        """Dispatch a mouse button press to the matching handler.

        Buttons Ids:
        1: left-click: start the rectangular zoom
        2: wheel-click: ignored
        3: right-click: start dragging the map
        Any other button is ignored as well.
        """
        # affectation of action to button id
        action = {
            1: self.mouse_press_zoom,
            2: self.mouse_donothing,
            3: self.mouse_press_drag
        }
        # execution of action; buttons without an entry fall back to the
        # no-op handler instead of raising KeyError
        action.get(event.button, self.mouse_donothing)(event)

    # end of method mouse_click

    def mouse_release(self, event):
        """Dispatch a mouse button release to the matching handler.

        Buttons Ids:
        1: left-click: finish the rectangular zoom
        2: wheel-click: ignored
        3: right-click: stop dragging the map
        Any other button is ignored as well.
        """
        # affectation of action to button id
        action = {
            1: self.mouse_release_zoom,
            2: self.mouse_donothing,
            3: self.mouse_release_drag
        }
        # execution of action; buttons without an entry fall back to the
        # no-op handler instead of raising KeyError
        action.get(event.button, self.mouse_donothing)(event)

    # end of method mouse_click

    def mouse_donothing(self, _):
        """Ignore the event.

        Placeholder handler used in the dispatch tables of mouse_press and
        mouse_release for buttons that trigger no action.
        """
        return None

    # end of method mouse_donothing

    def mouse_set_viewer(self, event):
        """Move the viewer to the map position under the mouse event.

        The longitude and latitude of the event are rounded to one decimal,
        stored in the viewer, the plot is redrawn and the application is
        told about the new viewer position.  An event whose position has no
        longitude or latitude (``get_mouse_ll`` returned NaN) is ignored so
        that the viewer never receives a NaN coordinate.
        """
        bbox = event.canvas.figure.axes[0].bbox
        lon, lat = self.get_mouse_ll(event.x, event.y, bbox)
        if np.isnan(lon) or np.isnan(lat):
            return
        self._viewer.longitude(round(lon, 1))
        self._viewer.latitude(round(lat, 1))
        self.draw_elements()
        app = self.parent().parent()
        app.setviewerpos(self._viewer.longitude(), self._viewer.latitude(),
                         self._viewer.altitude())

    # end of method mouse_set_viewer

    def mouse_press_zoom(self, event):
        """Start a rectangular zoom when the press comes with Control held.

        Called on a mouse press event.  Its role is to store the first
        corner of the zoom rectangle (azimuth, elevation, longitude,
        latitude, x, y) in ``self.zoomposorigin`` and to add an empty
        rectangle patch at that corner to the axes.  Without the Control
        key the press is ignored.
        """
        if event.key != 'control':
            return

        axes_bbox = event.canvas.figure.axes[0].bbox
        mouse_x, mouse_y = event.x, event.y
        az, el = self.get_mouse_azel(mouse_x, mouse_y, axes_bbox)
        lon, lat = self.get_mouse_ll(mouse_x, mouse_y, axes_bbox)
        x, y = self.get_mouse_xy(mouse_x, mouse_y, axes_bbox)
        self.zoomposorigin = (az, el, lon, lat, x, y)
        # zero-sized outline, stretched later while the mouse moves
        self.zoompatch = Rectangle(xy=(x, y),
                                   width=0,
                                   height=0,
                                   fill=False,
                                   linewidth=0.2)
        self._axes.add_patch(self.zoompatch)

    # end of method mouse_press_zoom

    def mouse_release_zoom(self, _):
        """Finish or cancel the rectangular zoom on mouse release.

        Nothing happens when no zoom was started.  When the button is
        released before any final corner was recorded (no mouse motion
        since the press), the pending zoom is cancelled: the rectangle is
        removed and the zoom state cleared, so that later mouse motion does
        not keep stretching a rectangle.  Otherwise the zoom window is
        updated, provided the rectangle covers more than 5% of each axis
        dimension, and the plot is redrawn.
        """
        if self.zoomposorigin is None:
            return

        if self.zoomposfinal is None:
            # click without motion: abort, same clean-up as the Escape key
            self.zoomposorigin = None
            self.zoompatch.remove()
            self.zoompatch = None
            self.draw()
            return

        azorigin, elorigin, \
            lonorigin, latorigin, \
            xorigin, yorigin = self.zoomposorigin
        azfinal, elfinal, \
            lonfinal, latfinal, \
            xfinal, yfinal = self.zoomposfinal
        xzoom = abs(xfinal - xorigin) / self.get_width()
        yzoom = abs(yfinal - yorigin) / self.get_height()
        # authorize zooming if bigger than 5% of each axis dimension
        if xzoom > 0.05 and yzoom > 0.05:
            if self._projection == 'nsper':
                self._zoom.min_azimuth = round(min(azorigin, azfinal), 1)
                self._zoom.min_elevation = round(min(elorigin, elfinal), 1)
                self._zoom.max_azimuth = round(max(azorigin, azfinal), 1)
                self._zoom.max_elevation = round(max(elorigin, elfinal), 1)
            elif self._projection == 'cyl':
                self._zoom.min_longitude = round(min(lonorigin, lonfinal),
                                                 1)
                self._zoom.min_latitude = round(min(latorigin, latfinal),
                                                1)
                self._zoom.max_longitude = round(max(lonorigin, lonfinal),
                                                 1)
                self._zoom.max_latitude = round(max(latorigin, latfinal),
                                                1)
        self.zoomposorigin = None
        self.zoomposfinal = None
        self.zoompatch.remove()
        self.zoompatch = None
        self.draw_elements()
        self.draw_axis()

    # end of method mouse_release_event

    def mouse_press_drag(self, event):
        """Record the longitude/latitude under the mouse as the drag origin."""
        axes_bbox = event.canvas.figure.axes[0].bbox
        self.dragorigin = self.get_mouse_ll(event.x, event.y, axes_bbox)

    def mouse_release_drag(self, _):
        """Stop the drag by clearing the stored drag origin."""
        self.dragorigin = None
        return None

    def key_press(self, event):
        """Handle key_press event.

        Keys with a registered shortcut are passed to their handler; any
        other key is ignored.  The lookup is done without a try/except
        around the call so that a ``KeyError`` raised inside a handler is
        not silently discarded.
        """
        action = {'escape': self.key_press_esc}
        handler = action.get(event.key)
        if handler is not None:
            handler(event)

    def key_press_esc(self, _):
        """Abort a rectangular zoom in progress when Escape is pressed.

        Does nothing when no zoom origin is stored; otherwise the zoom
        state is cleared, the rectangle is removed and the canvas redrawn.
        """
        if self.zoomposorigin is None:
            return
        self.zoomposorigin = self.zoomposfinal = None
        self.zoompatch.remove()
        self.zoompatch = None
        self.draw()

    def draw_elements(self):
        """Redraw all elements of the Earth plot.

        The axes are cleared, the zoom is updated, then the Earth
        background, the patterns, the elevation contours, the stations and
        the polygons are plotted in that order, followed by the axis
        decoration and the canvas redraw.
        """
        utils.trace('in')
        # clear display and reset it
        self._axes.clear()

        # update the zoom
        self.updatezoom()

        # Draw Earth in the background
        self.drawearth(proj=self._projection, resolution=self._resolution)

        # draw all patterns
        at_least_one_slope = False
        for controler in self._patterns.values():
            controler.plot()
            pattern_conf = controler.get_config()
            if 'display_slope' in pattern_conf and \
                    pattern_conf['display_slope']:
                at_least_one_slope = True
        if not at_least_one_slope and len(self._patterns):
            # no pattern displays a slope: drop every axes but the first
            # (presumably left-over colour bar axes -- TODO confirm).
            # Iterate over a snapshot because deleting an axes shrinks
            # figure.axes, which made index based deletion skip entries
            # and run past the end of the list.
            for extra_axes in list(self._figure.axes)[1:]:
                self._figure.delaxes(extra_axes)

        # draw all Elevation contour
        for contour in self._elev.values():
            contour.plot()

        # draw stations
        for s in self._stations:
            s.clearplot()
            s.plot()

        # draw polygons
        for p in self._polygons:
            p.clearplot()
            p.plot()

        # draw axis
        self.draw_axis()

        # call to the canvas draw function
        self.draw()
        utils.trace('out')

    # end of draw_elements function

    def draw_axis(self):
        """Label the axes, place the ticks and set the plot title.

        With the 'nsper' projection the axes show azimuth and elevation,
        with ticks every 2 units of the zoom window placed relative to the
        viewer position.  With the 'cyl' projection they show longitude and
        latitude, with ticks every 20 units starting at the zoom minimum
        truncated to a multiple of ten.  Tick labels use one decimal.
        """
        utils.trace('in')
        axes = self._axes
        one_decimal = '{0:0.1f}'.format
        if self._projection == 'nsper':
            zoom = self._zoom
            axes.set_xlabel('Azimuth (deg)')
            axes.set_ylabel('Elevation (deg)')
            # viewer coordinates in the rendering frame
            viewer_x, viewer_y = self._earth_map(self._viewer.longitude(),
                                                 self._viewer.latitude())
            # x-axis ticks
            azticks = np.arange(zoom.min_azimuth, zoom.max_azimuth + 0.1, 2)
            axes.set_xticks(self.az2x(azticks) + viewer_x)
            axes.set_xticklabels(one_decimal(f) for f in azticks)
            # y-axis ticks
            elticks = np.arange(zoom.min_elevation,
                                zoom.max_elevation + 0.1, 2)
            axes.set_yticks(self.el2y(elticks) + viewer_y)
            axes.set_yticklabels(one_decimal(f) for f in elticks)
        elif self._projection == 'cyl':
            zoom = self._zoom
            axes.set_xlabel('Longitude (deg)')
            axes.set_ylabel('Latitude (deg)')
            # x-axis ticks, converted along the lowest latitude
            first_lon = int(zoom.min_longitude / 10) * 10
            lonticks = np.arange(first_lon, zoom.max_longitude + 0.1, 20)
            lonticks_converted, _ = np.array(
                self._earth_map(lonticks,
                                np.ones(lonticks.shape) * zoom.min_latitude))
            axes.set_xticks(lonticks_converted)
            axes.set_xticklabels(one_decimal(f) for f in lonticks)
            # y-axis ticks, converted along the lowest longitude
            first_lat = int(zoom.min_latitude / 10) * 10
            latticks = np.arange(first_lat, zoom.max_latitude + 0.1, 20)
            _, latticks_converted = np.array(
                self._earth_map(np.ones(latticks.shape) * zoom.min_longitude,
                                latticks))
            axes.set_yticks(latticks_converted)
            axes.set_yticklabels(one_decimal(f) for f in latticks)
        axes.tick_params(axis='both', width=0.2)
        axes.set_title(self._plot_title)
        utils.trace('out')

    # end of function draw_axis

    def settitle(self, title: str):
        """Store ``title`` as the plot title and apply it to the axes."""
        utils.trace('in')
        self._plot_title = title
        self._axes.set_title(title)
        utils.trace('out')

    # end of method settitle

    # Change observer Longitude
    def setviewerlongitude(self, lon):
        """Pass ``lon`` to the viewer as its longitude; nothing is redrawn."""
        utils.trace()
        viewer = self._viewer
        viewer.longitude(lon)

    # end of method setviewerlongitude

    # Draw Earth and return Basemap handler
    def drawearth(self, proj='nsper', resolution='c'):
        """Rebuild the Basemap for the current view and redraw its layers.

        A new Basemap is created for ``proj`` ('nsper' or 'cyl'; any other
        value leaves the previous ``_earth_map`` in place), the optional
        Blue Marble background is refreshed, and coast lines, country
        borders, parallels and meridians are removed and redrawn according
        to the matching ``_coastlines`` / ``_countries`` / ``_parallels`` /
        ``_meridians`` settings ('no line' disables a layer).

        ``resolution`` is forwarded to Basemap:
        c: crude, l: low, i: intermediate, h: high, f: full.

        Returns the Basemap instance stored in ``_earth_map``.
        """
        utils.trace('in')

        ax = self._axes

        # Drop the previous background image and forget its handle, so a
        # later call never tries to remove the same artist twice.
        if self._bluemarble_imshow is not None:
            self._bluemarble_imshow.remove()
            self._bluemarble_imshow = None

        if proj == 'nsper':
            self._earth_map = Basemap(projection='nsper',
                                      llcrnrx=self.llcrnrx,
                                      llcrnry=self.llcrnry,
                                      urcrnrx=self.urcrnrx,
                                      urcrnry=self.urcrnry,
                                      lon_0=self._viewer.longitude(),
                                      lat_0=self._viewer.latitude(),
                                      satellite_height=self._viewer.altitude(),
                                      resolution=resolution,
                                      ax=ax)
            # display Blue Marble picture, projected and cropped
            if self._bluemarble:
                self._bluemarble_imshow = self.croppedbluemarble()
            else:
                self._bluemarble_imshow = None

        elif proj == 'cyl':
            self._earth_map = Basemap(projection=proj,
                                      llcrnrlat=self.llcrnrlat,
                                      urcrnrlat=self.urcrnrlat,
                                      llcrnrlon=self.llcrnrlon,
                                      urcrnrlon=self.urcrnrlon,
                                      lon_0=self._viewer.longitude(),
                                      lat_0=self._viewer.latitude(),
                                      lat_ts=self._viewer.latitude(),
                                      resolution=resolution,
                                      ax=ax)
            if self._bluemarble:
                self._bluemarble_imshow = self._earth_map.bluemarble(scale=0.5)
            else:
                self._bluemarble_imshow = None

        # Earth map drawing options
        # 1. Drawing coast lines
        if self._coastlines_col:
            try:
                # coast lines LineCollection can be removed at once
                self._coastlines_col.remove()
            except ValueError:
                print('drawearth: issue removing coastlines.')
            # The artist is gone: clear the handle, otherwise a redraw with
            # coast lines disabled would call remove() on it a second time.
            self._coastlines_col = None
        if self._coastlines != 'no line':
            self._coastlines_col = \
                self._earth_map.drawcoastlines(
                    linewidth=cst.BOLDNESS[self._coastlines])
        # 2. Drawing countries borders
        if self._countries_col:
            try:
                # Country borders LineCollection can be removed at once
                self._countries_col.remove()
            except ValueError:
                print('drawearth: issue removing borders.')
            # same reasoning as for the coast lines handle
            self._countries_col = None
        if self._countries != 'no line':
            self._countries_col = \
                self._earth_map.drawcountries(
                    linewidth=cst.BOLDNESS[self._countries])
        # 3. Drawing parallels
        if self._parallels_col:
            try:
                # Parallels are a dictionary of 2D lines to be
                # removed one by one.
                # NOTE(review): only the first line of each entry is
                # removed here; confirm no extra segments or text labels
                # are left behind on the axes.
                for k in self._parallels_col:
                    self._parallels_col[k][0][0].remove()
                self._parallels_col.clear()
            except ValueError:
                print('drawearth: issue removing parallels.')
        if self._parallels != 'no line':
            self._parallels_col = \
                self._earth_map.drawparallels(
                    np.arange(-80., 81., 20.),
                    linewidth=cst.BOLDNESS[self._parallels])
        # 4. Drawing meridians
        if self._meridians_col:
            try:
                # Meridians are a dictionary of 2D lines to be
                # removed one by one (same caveat as for parallels)
                for k in self._meridians_col:
                    self._meridians_col[k][0][0].remove()
                self._meridians_col.clear()
            except ValueError:
                print('drawearth: issue removing meridians.')
        if self._meridians != 'no line':
            self._meridians_col = \
                self._earth_map.drawmeridians(
                    np.arange(-180., 181., 20.),
                    linewidth=cst.BOLDNESS[self._meridians])
        # Unconditional drawing of Earth boundary
        self._earth_map.drawmapboundary(linewidth=0.2)

        utils.trace('out')
        return self._earth_map

    # end of drawEarth function

    # Draw isoElevation contours
    def drawelevation(self, level=(10, 20, 30)):
        """Draw iso-elevation contours on the Earth map.

        The elevation is sampled on a 200 x 200 grid spanning the map
        extent and contoured at the values in ``level`` with thin dotted
        black lines. Returns the object produced by ``Basemap.contour``.
        """
        utils.trace('in')
        earth = self._earth_map
        n_x = 200
        n_y = 200
        # regular grid in map coordinates covering the whole map
        x_mesh, y_mesh = np.meshgrid(
            np.linspace(earth.xmin, earth.xmax, n_x),
            np.linspace(earth.ymin, earth.ymax, n_y))
        # back to longitude / latitude to evaluate the elevation
        lon_mesh, lat_mesh = earth(x_mesh, y_mesh, inverse=True)
        elev_mesh = self.elevation(lon_mesh, lat_mesh)
        contours = earth.contour(x_mesh,
                                 y_mesh,
                                 elev_mesh,
                                 level,
                                 colors='black',
                                 linestyles='dotted',
                                 linewidths=0.5)
        utils.trace('out')
        return contours

    # end of drawelevation

    def elevation(self, stalon, stalat):
        """Compute the elevation of the spacecraft seen from ground stations.

        ``stalon`` / ``stalat`` are station longitudes and latitudes (they
        are scaled by ``cst.DEG2RAD`` before use). The result has the shape
        of the input; stations whose longitude differs from the viewer
        longitude by 90 or more are set to -1.
        """
        utils.trace('in')
        viewer_lon = self._viewer.longitude()
        # angle between the station and the sub-satellite point
        central_angle = np.arccos(
            np.cos(cst.DEG2RAD * stalat) *
            np.cos(cst.DEG2RAD * (viewer_lon - stalon)))

        # Earth radius over the satellite's distance to the Earth centre
        radius_ratio = (cst.EARTH_RAD_EQUATOR_M /
                        (cst.EARTH_RAD_EQUATOR_M + self._viewer.altitude()))

        def _elevation_deg(angle):
            # station exactly under the spacecraft: looking straight up
            if angle == 0:
                return 90
            return cst.RAD2DEG * np.arctan(
                (np.cos(angle) - radius_ratio) / np.sin(angle))

        elev = np.reshape(
            [_elevation_deg(angle) for angle in central_angle.flatten()],
            central_angle.shape)

        # remove station out of view
        # NOTE(review): the plain difference below does not wrap around
        # +/-180 degrees; confirm behaviour for viewers near the date line.
        lon_offset = np.absolute(stalon - viewer_lon)
        elev = np.where(lon_offset < 90, elev, -1)

        utils.trace('out')
        return elev

    # end of function elevation

    def get_file_key(self, filename):
        """Build a key for ``filename`` that is unused in ``_patterns``.

        The key is the file base name followed by a space and the lowest
        index in 1..50 not yet present in ``_patterns``. ``filename`` may
        be a path or a list/tuple of paths, in which case the first one
        is used. When all 50 indexes are taken, a warning is printed and
        the key with index 50 is returned, so that entry gets overwritten.
        """
        utils.trace('in')
        max_index = 50
        if isinstance(filename, (list, tuple)):
            filename = filename[0]
        base = os.path.basename(filename)

        file_index = 1
        file_key = base + ' ' + str(file_index)
        # stop AT the last allowed index instead of running one past it
        while file_key in self._patterns and file_index < max_index:
            file_index += 1
            file_key = base + ' ' + str(file_index)
        # still taken here means every slot is used: warn about overwrite
        if file_key in self._patterns:
            print(('Max repetition of same file reached.'
                   ' Index 50 will be overwritten'))
        utils.trace('out')
        return file_key

    # end of function get_file_key

    def loadpattern(self, conf=None):
        """Load and display a grd file.

        ``conf`` is a dictionary that must contain 'filename'. It is
        updated in place: 'key' receives the generated pattern key and,
        when 'sat_lon' is absent, 'sat_lon' / 'sat_lat' / 'sat_alt' are
        filled from the viewer (the pattern is then configured with
        ``dialog=True``).

        Returns the PatternControler stored in ``_patterns``, or None when
        no file name is given or the pattern could not be created.
        """
        utils.trace('in')
        try:
            filename = conf['filename']
        except (KeyError, TypeError):
            # TypeError covers the default conf=None (not subscriptable)
            print('load_pattern:File name is mandatory.')
            utils.trace('out')
            return None
        file_key = self.get_file_key(filename)
        conf['key'] = file_key
        try:
            pattern = PatternControler(parent=self, config=conf)
        except PatternNotCreatedError as pnc:
            print(pnc)
            utils.trace('out')
            return None
        # without an explicit satellite position, use the viewer's one
        # and let the pattern open its configuration dialog
        dialog = 'sat_lon' not in conf
        if dialog:
            conf['sat_lon'] = self._viewer.longitude()
            conf['sat_lat'] = self._viewer.latitude()
            conf['sat_alt'] = self._viewer.altitude()
        pattern.configure(dialog=dialog, config=conf)

        # Add grd in grd dictionary
        self._patterns[file_key] = pattern

        # refresh pattern combo box (first entry is the empty choice)
        itemlist = ['']
        itemlist.extend(self._patterns.keys())
        self._app.setpatterncombo(itemlist)
        # return pattern controler instance
        utils.trace('out')
        return self._patterns[file_key]

    # end of load_pattern

    # Zoom on the _earth_map
    def updatezoom(self):
        """Refresh the map corner and centre attributes from ``_zoom``.

        Azimuth / elevation limits are converted with ``az2x`` / ``el2y``;
        longitude / latitude limits are copied as they are. The centre
        values are the midpoints of the corresponding corners.
        """
        zoom = self._zoom
        # lower-left then upper-right corner, in projected x / y
        self.llcrnrx, self.llcrnry = (self.az2x(zoom.min_azimuth),
                                      self.el2y(zoom.min_elevation))
        self.urcrnrx, self.urcrnry = (self.az2x(zoom.max_azimuth),
                                      self.el2y(zoom.max_elevation))
        # same corners in longitude / latitude
        self.llcrnrlon, self.llcrnrlat = zoom.min_longitude, zoom.min_latitude
        self.urcrnrlon, self.urcrnrlat = zoom.max_longitude, zoom.max_latitude
        # midpoints
        self.centerx = (self.llcrnrx + self.urcrnrx) / 2
        self.centery = (self.llcrnry + self.urcrnry) / 2
        self.cntrlon = (self.llcrnrlon + self.urcrnrlon) / 2
        self.cntrlat = (self.llcrnrlat + self.urcrnrlat) / 2

    # end of method updatezoom

    def get_width(self):
        """Return the horizontal extent of the current view.

        This is the x span for the 'nsper' projection, the longitude span
        for 'cyl', and 0 for any other projection value.
        """
        proj = self._projection
        if proj == 'nsper':
            return self.urcrnrx - self.llcrnrx
        if proj == 'cyl':
            return self.urcrnrlon - self.llcrnrlon
        return 0

    # end of function get_width

    def get_height(self):
        """Return the vertical extent of the current view.

        This is the y span for the 'nsper' projection, the latitude span
        for 'cyl', and 0 for any other projection value.
        """
        proj = self._projection
        if proj == 'nsper':
            return self.urcrnry - self.llcrnry
        if proj == 'cyl':
            return self.urcrnrlat - self.llcrnrlat
        return 0

    # end of function get_width

    # convert Azimuth to _earth_map.x
    def az2x(self, az):
        """Convert an azimuth to a map x value.

        ``az`` is scaled by ``cst.DEG2RAD``; its tangent is multiplied by
        the viewer altitude.
        """
        az_rad = az * cst.DEG2RAD
        return np.tan(az_rad) * self._viewer.altitude()

    # convert Elevation to _earth_map.y
    def el2y(self, el):
        """Convert an elevation to a map y value.

        ``el`` is scaled by ``cst.DEG2RAD``; its tangent is multiplied by
        the viewer altitude.
        """
        el_rad = el * cst.DEG2RAD
        return np.tan(el_rad) * self._viewer.altitude()

    def x2az(self, x):
        """Convert a map x value back to an azimuth (inverse of ``az2x``).

        The angle whose tangent is ``x`` over the viewer altitude is
        scaled by ``cst.RAD2DEG``.
        """
        az_rad = np.arctan2(x, self._viewer.altitude())
        return az_rad * cst.RAD2DEG

    def y2el(self, y):
        """Convert a map y value back to an elevation (inverse of ``el2y``).

        The angle whose tangent is ``y`` over the viewer altitude is
        scaled by ``cst.RAD2DEG``.
        """
        el_rad = np.arctan2(y, self._viewer.altitude())
        return el_rad * cst.RAD2DEG

    def projection(self, proj: str = None):
        """Get, and optionally set, the ``_projection`` attribute.

        A truthy ``proj`` must be 'nsper' or 'cyl' and replaces the stored
        value; anything else raises ValueError. The stored projection is
        returned in all other cases.
        """
        utils.trace('in')
        if proj:
            # reject unknown names before touching the attribute
            if proj not in ('nsper', 'cyl'):
                raise ValueError("Projection is either 'nsper' or 'cyl'.")
            self._projection = proj
        utils.trace('out')
        return self._projection

    # end of function projection

    def set_resolution(self, resolution: str = 'c'):
        """Store the Earth map resolution and redraw the plot elements.
        """
        utils.trace('in')
        self._resolution = resolution
        # keep the existing Basemap in sync before redrawing
        self._earth_map.resolution = resolution
        self.draw_elements()
        utils.trace('out')

    # end of function set_resolution

    def save(self, filename=None):
        """Write the plot to a file.

        A truthy ``filename`` is remembered in ``self.filename``; without
        one, the previously remembered name is used.
        """
        utils.trace('in')
        if filename:
            # remember the name for later calls made without argument
            self.filename = filename
        target = self.filename
        self.print_figure(target)
        utils.trace('out')

    # end of function save

    ###################################################################
    #
    #       Getters and Setters
    #
    ###################################################################

    def viewer(self, v=None):
        """Return the ``_viewer`` attribute, replacing it first when ``v``
        is not None.
        """
        if v is None:
            return self._viewer
        self._viewer = v
        return self._viewer

    def zoom(self, z=None):
        """Return the ``_zoom`` attribute, replacing it first when ``z``
        is not None.
        """
        if z is None:
            return self._zoom
        self._zoom = z
        return self._zoom

    def get_coastlines(self):
        """Return the current coast lines setting (``_coastlines``).
        """
        return self._coastlines

    # end of function get_coastlines

    def set_coastlines(self, c: str, refresh: bool = False):
        """Store ``c`` as the coast lines setting.

        With ``refresh`` set, the Earth and the axis are redrawn and the
        canvas repainted. Returns the stored setting.
        """
        utils.trace('in')
        self._coastlines = c
        if refresh:
            # rebuild the map with the new setting, then repaint
            self.drawearth(resolution=self._resolution,
                           proj=self._projection)
            self.draw_axis()
            self.draw()
        utils.trace('out')
        return self._coastlines

    # end of function set_coastlines

    def get_countries(self):
        """Return the current country borders setting (``_countries``).
        """
        return self._countries

    # end of function get_countries

    def set_countries(self, c: str, refresh: bool = False):
        """Store ``c`` as the country borders setting.

        With ``refresh`` set, the Earth and the axis are redrawn and the
        canvas repainted. Returns the stored setting.
        """
        utils.trace('in')
        self._countries = c
        if refresh:
            # rebuild the map with the new setting, then repaint
            self.drawearth(resolution=self._resolution,
                           proj=self._projection)
            self.draw_axis()
            self.draw()
        utils.trace('out')
        return self._countries

    # end of function set_countries

    def get_parallels(self):
        """Return the current parallels setting (``_parallels``).
        """
        return self._parallels

    # end of function get_parallels

    def set_parallels(self, p: str, refresh: bool = False):
        """Store ``p`` as the parallels setting.

        With ``refresh`` set, the Earth and the axis are redrawn and the
        canvas repainted. Returns the stored setting.
        """
        utils.trace('in')
        self._parallels = p
        if refresh:
            # rebuild the map with the new setting, then repaint
            self.drawearth(resolution=self._resolution,
                           proj=self._projection)
            self.draw_axis()
            self.draw()
        utils.trace('out')
        return self._parallels

    # end of function set_parallels

    def get_meridians(self):
        """Return the current meridians setting (``_meridians``).
        """
        return self._meridians

    # end of function get_meridians

    def set_meridians(self, m: str, refresh: bool = False):
        """Store ``m`` as the meridians setting.

        With ``refresh`` set, the Earth and the axis are redrawn and the
        canvas repainted. Returns the stored setting.
        """
        utils.trace('in')
        self._meridians = m
        if refresh:
            # rebuild the map with the new setting, then repaint
            self.drawearth(resolution=self._resolution,
                           proj=self._projection)
            self.draw_axis()
            self.draw()
        utils.trace('out')
        return self._meridians

    # end of function set_meridians

    def get_centralwidget(self):
        """Return the central widget stored in ``_centralwidget``.
        """
        return self._centralwidget

    # end of get_centralwidget

    def croppedbluemarble(self):
        """Draw the Blue Marble image and fit its pixels to the zoom window.

        The image returned by ``Basemap.bluemarble`` is copied into a new
        zero-filled array sized from the current azimuth / elevation zoom
        limits, and the image object is updated with that array.

        Returns the image object produced by ``Basemap.bluemarble``.
        """
        # get blue marble data projected on the current projection
        im = self._earth_map.bluemarble(alpha=0.9, scale=0.5)
        data = im.get_array()

        # get data array dimension
        # NOTE(review): the FIRST axis is named nx here, while new_data is
        # allocated and indexed as [y, x] below - confirm the axis order
        # for non-square images.
        nx, ny, _ = data.shape
        # angular size of one source pixel, assuming the source image
        # spans the whole Earth disc - TODO confirm
        ead = self.earth_angular_diameter()
        stepx = ead / (nx - 1)
        stepy = ead / (ny - 1)

        # create new matrix to the dimension of the current plot
        azmin = self._zoom.min_azimuth
        azmax = self._zoom.max_azimuth
        elmin = self._zoom.min_elevation
        elmax = self._zoom.max_elevation
        # "/ 2 ... * 2 + 1" forces an odd number of pixels
        new_nx = int((azmax - azmin) / stepx / 2) * 2 + 1
        new_ny = int((elmax - elmin) / stepy / 2) * 2 + 1
        new_data = np.zeros((new_ny, new_nx, 4))

        # compute first azimuth index of source array and destination array
        x0_source = 0
        x0_destination = 0
        if azmin > -ead / 2:
            # crop in azimuth: the window starts inside the source image
            x0_source = int(np.abs(azmin + ead / 2) / stepx)
        else:
            # window starts left of the source: offset the destination
            x0_destination = int(np.abs(azmin + ead / 2) / stepx)
        # if destination array smaller than origin array, limit source array
        if new_nx - x0_destination < nx - x0_source:
            x_source = range(x0_source, x0_source + new_nx - x0_destination)
        else:
            x_source = range(x0_source, nx)
        x_destination = range(x0_destination, x0_destination + len(x_source))

        # compute first elevation index of source array and destination array
        y0_source = 0
        y0_destination = 0
        if elmin > -ead / 2:
            # crop in elevation: the window starts inside the source image
            y0_source = int(np.abs(elmin + ead / 2) / stepy)
        else:
            # window starts below the source: offset the destination
            y0_destination = int(np.abs(elmin + ead / 2) / stepy)
        # if destination array smaller than origin array, limit source array
        if new_ny - y0_destination < ny - y0_source:
            y_source = range(y0_source, y0_source + new_ny - y0_destination)
        else:
            y_source = range(y0_source, ny)
        y_destination = range(y0_destination, y0_destination + len(y_source))

        # first and last index of every range
        # (indexing raises IndexError when a range is empty)
        x0_src = x_source[0]
        x1_src = x_source[-1]
        y0_src = y_source[0]
        y1_src = y_source[-1]
        x0_des = x_destination[0]
        x1_des = x_destination[-1]
        y0_des = y_destination[0]
        y1_des = y_destination[-1]
        # x1_* / y1_* are the LAST indexes and slices exclude their stop,
        # so the final row and column of each range are not copied.
        # NOTE(review): confirm whether that is intended.
        new_data[y0_des:y1_des, x0_des:x1_des] = \
            data[y0_src:y1_src, x0_src:x1_src]
        im.set_array(new_data)
        return im

    # end of method croppedbluemarble

    def earth_angular_diameter(self):
        """Compute the Earth angular diameter seen from the spacecraft.

        The value depends on the viewer altitude and is scaled by
        ``cst.RAD2DEG``.
        """
        earth_radius = cst.EARTH_RAD_BASEMAP
        # distance from the Earth centre to the spacecraft
        sat_distance = earth_radius + self._viewer.altitude()
        # twice the half-angle between the centre and the limb directions
        return 2 * np.arcsin(earth_radius / sat_distance) * cst.RAD2DEG

    # end of function earth_angular_diameter

    def get_earthmap(self):
        """Return the map object stored in ``_earth_map``.
        """
        return self._earth_map

    # end of function get_earthmap

    def get_axes(self):
        """Return the axes object stored in ``_axes``.
        """
        return self._axes

    # end of function get_axes

    def bluemarble(self, set=None):
        """Return the Blue Marble flag, updating it first when ``set`` is
        not None (any truthy value stores True, any falsy one False).
        """
        if set is not None:
            self._bluemarble = bool(set)
        return self._bluemarble
Ejemplo n.º 14
0
def makeHoldBiteMovieLabeled(movInfo,
                             frontLabels,
                             sideLabels,
                             saveToFolder,
                             frontWindow=None,
                             sideWindow=None,
                             codec='MJPG',
                             ext='.avi'):
    """Render a labeled hold-bite session movie and write it to disk.

    Every output frame is 480 pixels high and 1280 wide: a 640 x 480
    matplotlib summary plot on the left, the front (top) and side (bottom)
    camera views in the middle column, and the accumulated body-part
    traces of each view in the right column. Frames are also shown in an
    OpenCV window; pressing ESC stops the rendering early.

    Parameters
    ----------
    movInfo : dict
        Session description. Keys read here: 'SessionFile', 'FrameRate',
        'FrameTimes', 'BiteFrames', 'WaterFrames', 'AllBiteTimes',
        'RewardTimes', 'TimeoutTimes', 'SessionDuration', 'Subject',
        'HoldTime', 'DateString', 'TimeString', 'BiteTraces' ('LogFile',
        'Timestamps', 'Analog', 'Digital') and 'FrontMovie' / 'SideMovie'
        ('Folder', 'ImageNames', 'Timestamps', optional 'Window').
    frontLabels, sideLabels : mapping or None
        Per body part, indexable 'x', 'y' and 'likelihood' sequences with
        one entry per frame. None disables the labels of that view.
    saveToFolder : str
        Folder receiving ``<session name><ext>``.
    frontWindow, sideWindow : 4-tuple, optional
        Crop box of each view (passed to ``Image.crop``). Defaults to the
        movie's 'Window' entry, then to (0, 0, 160, 120).
    codec : str
        FourCC code of the video codec.
    ext : str
        Extension of the output file.

    Raises
    ------
    ValueError
        If neither a front nor a side movie folder is given.
    """
    plt.close('all')
    frontWindow = _resolveCropWindow(movInfo, 'FrontMovie', frontWindow)
    sideWindow = _resolveCropWindow(movInfo, 'SideMovie', sideWindow)
    frontSize = (frontWindow[2] - frontWindow[0],
                 frontWindow[3] - frontWindow[1])
    sideSize = (sideWindow[2] - sideWindow[0],
                sideWindow[3] - sideWindow[1])

    baseName = os.path.basename(movInfo['SessionFile'])
    sessionStr = os.path.splitext(baseName)[0]
    fullMovPath = os.path.join(saveToFolder, sessionStr + ext)

    hasFront = len(movInfo['FrontMovie']['Folder']) > 0
    hasSide = len(movInfo['SideMovie']['Folder']) > 0
    hasLog = len(movInfo['BiteTraces']['LogFile']) > 0
    if not (hasFront or hasSide):
        raise ValueError('Must have at least one video to label.')

    hasFrontLabels = frontLabels is not None
    hasSideLabels = sideLabels is not None

    bodyPartsColsF = {
        'leftPupil': (180, 230, 245, 128),
        'rightPupil': (0, 71, 160, 128),
        'nose': (36, 211, 154, 128)
    }
    bodyPartsColsS = {
        'nose': (36, 211, 154, 128),
        'tearDuct': (180, 230, 245, 128),
        'pupil': (205, 2, 43, 128),
        'backCorner': (0, 71, 160, 128)
    }

    # output layout: plot | camera views | traces
    plotHeight = 480
    plotWidth = 640
    him = 240
    wim = 320
    (hf, wf) = (plotHeight, plotWidth + 2 * wim)
    outputFrame = np.zeros((hf, wf, 3), dtype=np.uint8)

    fontSize = 20
    axLabelFontSize = 15
    axLabelFontName = 'DejaVu Sans'
    mpl.rc('font', **{'family': 'sans-serif', 'weight': 'light', 'size': 18})
    font = ImageFont.truetype(
        os.path.join(Path().absolute(), 'Roboto', 'Roboto-Black.ttf'),
        fontSize)
    timestampCol = (255, 255, 255)
    biteCol = 'forestgreen'
    waterCol = 'darkturquoise'
    timeoutCol = 'crimson'

    # trace images accumulate the body-part paths over the whole session
    frontImTrace = Image.new('RGB', frontSize)
    sideImTrace = Image.new('RGB', sideSize)
    drawObjFTrace = ImageDraw.Draw(frontImTrace)
    drawObjSTrace = ImageDraw.Draw(sideImTrace)
    prevXYF = {}
    prevXYS = {}
    if hasFrontLabels:
        prevXYF = _drawPartLegend(drawObjFTrace, bodyPartsColsF, font,
                                  fontSize)
    if hasSideLabels:
        prevXYS = _drawPartLegend(drawObjSTrace, bodyPartsColsS, font,
                                  fontSize)

    # make plots
    plotFig, ax = plt.subplots(3, 1)
    canvas = FigureCanvas(plotFig)
    windAx = ax[0]
    fullAx = ax[1]
    avgAx = ax[2]
    for axis in (windAx, fullAx, avgAx):
        axis.tick_params(bottom=False, left=False)
        axis.set_frame_on(False)

    windowSize = 20
    biteWindow = 1

    # (n, 2) on/off arrays, also for a single interval or no interval
    allBiteTimes = _asIntervalArray(movInfo['AllBiteTimes'])
    allBiteDurs = allBiteTimes[:, 1] - allBiteTimes[:, 0]
    avgBiteDur = movmean(allBiteDurs, windowSize)
    rewardTimes = _asIntervalArray(movInfo['RewardTimes'])
    timeoutTimes = _asIntervalArray(movInfo['TimeoutTimes'])

    avgAx.set_xlabel('Time (s)',
                     fontsize=axLabelFontSize,
                     fontname=axLabelFontName,
                     fontweight='light')
    avgAx.set_ylabel('Avg Duration (s)',
                     fontsize=axLabelFontSize,
                     fontname=axLabelFontName,
                     fontweight='light')
    avgAx.set_xlim(0, movInfo['SessionDuration'])

    fullAx.set_ylim(-0.1, 3.1)
    fullAx.set_xlim(0, movInfo['SessionDuration'])
    fullAx.set_yticks([0.5, 1.5, 2.5])
    fullAx.set_yticklabels(['Timeout', 'Water', 'Biting'],
                           fontsize=axLabelFontSize,
                           fontname=axLabelFontName,
                           fontweight='light')
    fullAx.set_xticks([])
    fullAx.set_xticklabels([])

    windAx.set_title(' '.join([
        movInfo['Subject'],
        'HoldTime: %03d ms on' % movInfo['HoldTime'], movInfo['DateString'],
        'at', movInfo['TimeString']
    ]))
    windAx.set_ylim(-0.1, 3.1)
    windAx.set_xlim(0, biteWindow)
    windAx.set_yticks([0.5, 1.5, 2.5])
    windAx.set_yticklabels(['Timeout', 'Water', 'Biting'],
                           fontsize=axLabelFontSize,
                           fontname=axLabelFontName,
                           fontweight='light')

    avgAx.plot(allBiteTimes[:, 0],
               avgBiteDur,
               color='royalblue',
               linewidth=1,
               label='Avg Bite Duration (last %d bites)' % windowSize)
    avgAxYLim = avgAx.get_ylim()

    # one row of boxes per event type, added to both event axes
    eventRows = (
        (_intervalBoxes(allBiteTimes, 2.1), biteCol),
        (_intervalBoxes(rewardTimes, 1.1), waterCol),
        (_intervalBoxes(timeoutTimes, 0.1), timeoutCol),
    )
    for boxes, color in eventRows:
        for axis in (windAx, fullAx):
            axis.add_collection(
                PatchCollection(boxes, facecolor=color, edgecolor='None'))

    if hasLog:
        fullAx.plot(movInfo['BiteTraces']['Timestamps'],
                    movInfo['BiteTraces']['Analog'] + 2,
                    linewidth=0.5,
                    color='gold')
        fullAx.plot(movInfo['BiteTraces']['Timestamps'],
                    movInfo['BiteTraces']['Digital'] + 2,
                    linewidth=0.5,
                    color='mediumpurple')
        windAx.plot(movInfo['BiteTraces']['Timestamps'],
                    movInfo['BiteTraces']['Analog'] + 2,
                    linewidth=2,
                    color='gold',
                    label='Sensor')
        windAx.plot(movInfo['BiteTraces']['Timestamps'],
                    movInfo['BiteTraces']['Digital'] + 2,
                    linewidth=2,
                    color='mediumpurple',
                    label='Bpod')
        windAx.legend(loc='lower right')

    # the rendered plot must be exactly plotWidth x plotHeight pixels
    plotFig.set_size_inches(640 / plotFig.dpi, 480 / plotFig.dpi)

    canvas.draw()

    labels = avgAx.get_xticklabels(which='major')
    labs = [label.get_text() for label in labels]
    avgAx.set_xticklabels(labs, fontsize=axLabelFontSize, fontweight=1)

    windAx.set_xlim(-0.1, biteWindow + 0.1)
    lastBiteIdx = 0
    lcutoffF = 0.2
    lcutoffS = 0.2
    markerRadius = 2

    # initialize the FourCC and the video writer
    fourcc = cv2.VideoWriter_fourcc(*codec)
    print('Save to:', fullMovPath)
    writer = cv2.VideoWriter(fullMovPath, fourcc, movInfo['FrameRate'],
                             (wf, hf))
    try:
        for fidx, frameTime in enumerate(movInfo['FrameTimes']):
            # ---- summary plot with the current time window marked ----
            windAx.set_xlim(frameTime - 0.1, frameTime + biteWindow + 0.1)

            fullRect = Rectangle((frameTime, 0),
                                 biteWindow,
                                 3,
                                 fill=None,
                                 edgecolor='k',
                                 linewidth=2)
            fullAx.add_patch(fullRect)

            windRect = Rectangle((frameTime, 0),
                                 biteWindow,
                                 3,
                                 fill=None,
                                 edgecolor='k',
                                 linewidth=2)
            windAx.add_patch(windRect)
            bdl = '%03d ms' % int(1000 * avgBiteDur[lastBiteIdx])
            avl = avgAx.plot([frameTime, frameTime], [0, 1],
                             color='red',
                             linewidth=2,
                             label=bdl)
            avgAx.set_ylim(avgAxYLim)
            avleg = avgAx.legend(loc='lower right',
                                 fontsize=axLabelFontSize - 2)

            if (allBiteTimes[lastBiteIdx, 0] <= frameTime
                    and lastBiteIdx < len(allBiteTimes) - 1):
                lastBiteIdx += 1

            canvas.draw()  # draw the canvas, cache the renderer

            # RGBA bytes of the rendered figure; keep the RGB channels
            buf, (width, height) = canvas.print_to_buffer()
            plotIm = np.frombuffer(buf, dtype=np.uint8).reshape(
                height, width, 4)[:, :, :3]
            # remove the per-frame artists so they do not pile up
            avl[0].remove()
            fullRect.remove()
            windRect.remove()
            avleg.remove()
            if fidx == 0:
                # preview of the first rendered plot (kept from the
                # original implementation)
                plt.figure()
                plt.imshow(plotIm)

            biting = movInfo['BiteFrames'][fidx]
            deliveringReward = movInfo['WaterFrames'][fidx]

            # ---- front view ----
            if hasFront:
                frontIm = _openViewFrame(movInfo['FrontMovie'], frontWindow,
                                         fidx)
                drawObjF = ImageDraw.Draw(frontIm)
                ftf = movInfo['FrontMovie']['Timestamps'][fidx]
                _drawTimestamp(drawObjF, formatTimeStr(ftf), wim,
                               timestampCol, font)
                if hasFrontLabels:
                    _drawTrackedParts(frontLabels, fidx, biting, lcutoffF,
                                      markerRadius, bodyPartsColsF, prevXYF,
                                      drawObjF, drawObjFTrace)
            else:
                frontIm = Image.new('RGB', frontSize)
                drawObjF = ImageDraw.Draw(frontIm)

            # ---- side view ----
            if hasSide:
                sideIm = _openViewFrame(movInfo['SideMovie'], sideWindow,
                                        fidx)
                drawObjS = ImageDraw.Draw(sideIm)
                fts = movInfo['SideMovie']['Timestamps'][fidx]
                _drawTimestamp(drawObjS, formatTimeStr(fts), wim,
                               timestampCol, font)
                if hasSideLabels:
                    _drawTrackedParts(sideLabels, fidx, biting, lcutoffS,
                                      markerRadius, bodyPartsColsS, prevXYS,
                                      drawObjS, drawObjSTrace)
            else:
                sideIm = Image.new('RGB', sideSize)

            # draw biting/water state on the front view
            if biting:
                drawObjF.text((0, 0), 'BITING', fill=biteCol, font=font)
            if deliveringReward:
                drawObjF.text((0, fontSize + 1),
                              'WATER',
                              fill=waterCol,
                              font=font)

            frontOut = np.asarray(frontIm.resize((wim, him)), dtype=np.uint8)
            sideOut = np.asarray(sideIm.resize((wim, him)), dtype=np.uint8)
            frontTraceOut = np.asarray(frontImTrace.resize((wim, him)),
                                       dtype=np.uint8)
            sideTraceOut = np.asarray(sideImTrace.resize((wim, him)),
                                      dtype=np.uint8)

            # np.flip on the channel axis turns RGB into OpenCV's BGR
            outputFrame[0:hf, 0:plotWidth, :] = np.flip(plotIm, axis=2)
            outputFrame[0:him, plotWidth:plotWidth + wim, :] = np.flip(
                frontOut, axis=2)
            outputFrame[him:2 * him, plotWidth:plotWidth + wim, :] = np.flip(
                sideOut, axis=2)
            outputFrame[0:him,
                        plotWidth + wim:plotWidth + 2 * wim, :] = np.flip(
                            frontTraceOut, axis=2)
            outputFrame[him:2 * him,
                        plotWidth + wim:plotWidth + 2 * wim, :] = np.flip(
                            sideTraceOut, axis=2)

            # show the frame and append it to the movie
            cv2.imshow("OutputFrame", outputFrame)
            writer.write(outputFrame)
            key = cv2.waitKey(30)
            # key code 27 is ESC: stop rendering
            if key == 27:
                break
    finally:
        # always close the window and finalize the file, even on error
        print("[INFO] cleaning up...")
        cv2.destroyAllWindows()
        writer.release()


def _resolveCropWindow(movInfo, movieKey, window):
    """Return ``window``, or the movie's stored window when it is None."""
    if window is not None:
        return window
    try:
        return tuple(movInfo[movieKey]['Window'])
    except (KeyError, TypeError):
        # no usable 'Window' entry: fall back to the default crop box
        return (0, 0, 160, 120)


def _asIntervalArray(times):
    """Return on/off times as a 2-D array with one interval per row."""
    arr = np.asarray(times)
    if arr.size == 0:
        # no interval at all: zero rows, so loops over it do nothing
        return arr.reshape(0, 2)
    return np.atleast_2d(arr)


def _intervalBoxes(intervals, yBottom, boxHeight=0.8):
    """Build one Rectangle per (on, off) row, all at height ``yBottom``."""
    return [
        Rectangle((row[0], yBottom), row[1] - row[0], boxHeight)
        for row in intervals
    ]


def _drawPartLegend(drawObj, partColors, font, lineHeight):
    """Write each part name in its color; return the empty position store."""
    prevXY = {}
    for row, (part, color) in enumerate(partColors.items()):
        drawObj.text((0, row * lineHeight), part, fill=color, font=font)
        prevXY[part] = {'x': None, 'y': None, 'likelihood': None}
    return prevXY


def _openViewFrame(movie, window, fidx):
    """Load frame ``fidx`` of ``movie``: cropped, rotated 180, as RGB."""
    path = os.path.join(movie['Folder'], movie['ImageNames'][fidx])
    # the context manager closes the image file once the crop is taken
    with Image.open(path) as img:
        return img.crop(window).rotate(180).convert('RGB')


def _drawTimestamp(drawObj, text, rightEdge, color, font):
    """Draw ``text`` on the top row so that it ends at ``rightEdge``."""
    if hasattr(drawObj, 'textsize'):
        textWidth = drawObj.textsize(text, font=font)[0]
    else:
        # Pillow releases without ImageDraw.textsize
        textWidth = drawObj.textlength(text, font=font)
    # NOTE(review): callers pass the width of the RESIZED view, although
    # the text is drawn on the cropped image - confirm the intended edge.
    drawObj.text((rightEdge - textWidth, 0), text, fill=color, font=font)


def _drawTrackedParts(labels, fidx, biting, cutoff, radius, partColors,
                      prevXY, frameDraw, traceDraw):
    """Mark the confident body parts of frame ``fidx``.

    Each part whose likelihood exceeds ``cutoff`` gets a dot on the frame;
    while biting, its path is also extended on the trace image and its
    last position in ``prevXY`` is updated.
    """
    for part in labels:
        likelihood = labels[part]['likelihood'][fidx]
        # also skips NaN likelihoods, and avoids log10(0) below
        if not likelihood > cutoff:
            continue
        x = labels[part]['x'][fidx]
        y = labels[part]['y'][fidx]
        if biting:
            last = prevXY[part]
            if last['x'] is not None:
                traceDraw.line([last['x'], last['y'], int(x), int(y)],
                               fill=partColors[part],
                               width=1)
            last.update({'x': int(x), 'y': int(y)})
        # the dot shrinks slightly as the likelihood drops below 1
        ladd = np.log10(likelihood)
        frameDraw.ellipse([
            int(x - (radius + ladd)),
            int(y - (radius + ladd)),
            int(x + radius + ladd),
            int(y + (radius + ladd))
        ],
                          fill=partColors[part],
                          outline=None)
Ejemplo n.º 15
0
def _holdBiteCameraFrame(movie, fidx, window, size, font, stampCol):
    """Load one camera frame, orient it for display and stamp its timestamp.

    The image ``movie['ImageNames'][fidx]`` inside ``movie['Folder']`` is
    cropped to ``window``, rotated by 180 degrees, resized to ``size``
    (width, height) and converted to RGB.  The formatted value of
    ``movie['Timestamps'][fidx]`` is drawn right-aligned along the top edge.

    Returns:
        The annotated RGB ``PIL.Image``.
    """
    imPath = os.path.join(movie['Folder'], movie['ImageNames'][fidx])
    grayIm = Image.open(imPath)
    grayIm = grayIm.crop(window)
    grayIm = grayIm.rotate(180)
    grayIm = grayIm.resize(size)
    rgbIm = grayIm.convert('RGB')
    drawObj = ImageDraw.Draw(rgbIm)

    stampStr = formatTimeStr(movie['Timestamps'][fidx])
    # NOTE(review): ImageDraw.textsize is not available in recent Pillow
    # releases -- confirm the pinned Pillow version before upgrading.
    textWidth, _textHeight = drawObj.textsize(stampStr, font=font)
    drawObj.text((size[0] - textWidth, 0),
                 stampStr,
                 fill=stampCol,
                 font=font)
    return rgbIm


def makeHoldBiteMovie(movInfo,
                      saveToFolder,
                      frontWindow=None,
                      sideWindow=None,
                      codec='MJPG',
                      ext='.avi'):
    """Render a hold-bite session as an annotated movie and save it.

    Every output frame contains a three-row matplotlib summary (a sliding
    window starting at the current frame time, the whole session, and the
    moving-average bite duration).  When a front and/or side camera folder is
    configured, the two camera images are placed above the plot.  Frames are
    shown in an OpenCV window while they are written; pressing ESC stops
    early.

    Args:
        movInfo: Session dictionary.  Keys read here: 'SessionFile',
            'FrameRate', 'FrameTimes', 'SessionDuration', 'Subject',
            'HoldTime', 'DateString', 'TimeString', 'AllBiteTimes',
            'RewardTimes', 'TimeoutTimes', 'BiteFrames', 'WaterFrames',
            'BiteTraces' ('LogFile', 'Timestamps', 'Analog', 'Digital') and
            'FrontMovie' / 'SideMovie' ('Folder', 'ImageNames', 'Timestamps'
            and an optional 'Window').
        saveToFolder: Directory the movie is written to.  The file name is
            the session file's base name with ``ext`` appended.
        frontWindow: Crop box handed to ``Image.crop`` for front frames.
            Defaults to ``movInfo['FrontMovie']['Window']`` and, failing
            that, ``(0, 0, 160, 120)``.
        sideWindow: Same as ``frontWindow`` for the side camera.
        codec: FourCC string for ``cv2.VideoWriter``.
        ext: File extension of the written movie.
    """
    plt.close('all')

    defaultWindow = (0, 0, 160, 120)
    if frontWindow is None:
        try:
            frontWindow = tuple(movInfo['FrontMovie']['Window'])
        except Exception:
            frontWindow = defaultWindow
    if sideWindow is None:
        try:
            sideWindow = tuple(movInfo['SideMovie']['Window'])
        except Exception:
            sideWindow = defaultWindow

    baseName = os.path.basename(movInfo['SessionFile'])
    sessionStr, _ = os.path.splitext(baseName)
    fullMovPath = os.path.join(saveToFolder, sessionStr + ext)
    print('Save to:', fullMovPath)

    hasFront = len(movInfo['FrontMovie']['Folder']) > 0
    hasSide = len(movInfo['SideMovie']['Folder']) > 0
    hasLog = len(movInfo['BiteTraces']['LogFile']) > 0
    numVids = int(hasFront) + int(hasSide)

    plotHeight = 480
    plotWidth = 640
    # The camera row is only reserved when at least one camera is present.
    him, wim = (240, 320) if numVids > 0 else (0, 0)
    (hf, wf) = (plotHeight + him, plotWidth)

    outputFrame = np.zeros((hf, wf, 3)).astype(np.uint8)

    fontSize = 20
    axLabelFontSize = 15
    axLabelFontName = 'DejaVu Sans'

    mpl.rc('font', **{'family': 'sans-serif', 'weight': 'light', 'size': 18})
    font = ImageFont.truetype(
        os.path.join(Path().absolute(), 'Roboto', 'Roboto-Black.ttf'),
        fontSize)
    timestampCol = (255, 255, 255)
    biteCol = 'forestgreen'
    waterCol = 'darkturquoise'
    timeoutCol = 'crimson'

    # make plots
    plotFig, ax = plt.subplots(3, 1)
    canvas = FigureCanvas(plotFig)

    windAx = ax[0]
    fullAx = ax[1]
    avgAx = ax[2]
    for axis in (windAx, fullAx, avgAx):
        axis.tick_params(bottom=False, left=False)
        axis.set_frame_on(False)

    windowSize = 20
    biteWindow = 1

    # A session with a single bite may store one (on, off) pair as a 1-D
    # array; promote it to one row so all later [:, k] indexing works.
    allBiteTimes = np.array(movInfo['AllBiteTimes'])
    if len(allBiteTimes.shape) < 2:
        allBiteTimes = np.expand_dims(allBiteTimes, axis=0)
    allBiteDurs = allBiteTimes[:, 1] - allBiteTimes[:, 0]

    avgBiteDur = movmean(allBiteDurs, windowSize)
    rewardTimes = movInfo['RewardTimes']
    if len(rewardTimes.shape) < 2:
        rewardTimes = np.expand_dims(rewardTimes, axis=0)
    rewardTimes = np.array(rewardTimes)

    timeoutTimes = movInfo['TimeoutTimes']
    if len(timeoutTimes.shape) < 2:
        timeoutTimes = np.expand_dims(timeoutTimes, axis=0)
    timeoutTimes = np.array(timeoutTimes)

    avgAx.set_xlabel('Time (s)',
                     fontsize=axLabelFontSize,
                     fontname=axLabelFontName,
                     fontweight='light')
    avgAx.set_ylabel('Avg Duration (s)',
                     fontsize=axLabelFontSize,
                     fontname=axLabelFontName,
                     fontweight='light')
    avgAx.set_xlim(0, movInfo['SessionDuration'])

    fullAx.set_ylim(-0.1, 3.1)
    fullAx.set_xlim(0, movInfo['SessionDuration'])
    fullAx.set_yticks([0.5, 1.5, 2.5])
    fullAx.set_yticklabels(['Timeout', 'Water', 'Biting'],
                           fontsize=axLabelFontSize,
                           fontname=axLabelFontName,
                           fontweight='light')
    fullAx.set_xticks([])
    fullAx.set_xticklabels([])

    windAx.set_title(' '.join([
        movInfo['Subject'],
        'HoldTime: %03d ms on' % movInfo['HoldTime'], movInfo['DateString'],
        'at', movInfo['TimeString']
    ]))
    windAx.set_ylim(-0.1, 3.1)
    windAx.set_xlim(0, biteWindow)
    windAx.set_yticks([0.5, 1.5, 2.5])
    windAx.set_yticklabels(['Timeout', 'Water', 'Biting'],
                           fontsize=axLabelFontSize,
                           fontname=axLabelFontName,
                           fontweight='light')

    avgAx.plot(allBiteTimes[:, 0],
               avgBiteDur,
               color='royalblue',
               linewidth=1,
               label='Avg Bite Duration (last %d bites)' % windowSize)
    avgAxYLim = avgAx.get_ylim()

    # One box per event; each event type has its own row on the y axis.
    biteBoxes = []
    for bite in range(len(avgBiteDur)):
        biteOn = allBiteTimes[bite, 0]
        biteOff = allBiteTimes[bite, 1]
        biteBoxes.append(Rectangle((biteOn, 2.1), biteOff - biteOn, 0.8))
    waterBoxes = [
        Rectangle((reward[0], 1.1), reward[1] - reward[0], 0.8)
        for reward in rewardTimes
    ]
    timeoutBoxes = [
        Rectangle((timeout[0], 0.1), timeout[1] - timeout[0], 0.8)
        for timeout in timeoutTimes
    ]

    # Each axes needs its own collection built from the same boxes.
    for boxes, boxCol in ((biteBoxes, biteCol), (waterBoxes, waterCol),
                          (timeoutBoxes, timeoutCol)):
        for axis in (windAx, fullAx):
            axis.add_collection(
                PatchCollection(boxes, facecolor=boxCol, edgecolor='None'))

    if hasLog:
        traces = movInfo['BiteTraces']
        # Traces are shifted by 2 so they overlay the 'Biting' row.
        fullAx.plot(traces['Timestamps'],
                    traces['Analog'] + 2,
                    linewidth=0.5,
                    color='gold')
        fullAx.plot(traces['Timestamps'],
                    traces['Digital'] + 2,
                    linewidth=0.5,
                    color='mediumpurple')
        windAx.plot(traces['Timestamps'],
                    traces['Analog'] + 2,
                    linewidth=2,
                    color='gold',
                    label='Sensor')
        windAx.plot(traces['Timestamps'],
                    traces['Digital'] + 2,
                    linewidth=2,
                    color='mediumpurple',
                    label='Bpod')
        windAx.legend(loc='lower right')

    # The rendered plot must be exactly 640x480 to fit its slot in the frame.
    plotFig.set_size_inches(640 / plotFig.dpi, 480 / plotFig.dpi)

    canvas.draw()

    # Tick labels only have text after a draw; re-apply them with our font.
    labels = avgAx.get_xticklabels(which='major')
    labs = [label.get_text() for label in labels]
    avgAx.set_xticklabels(labs, fontsize=axLabelFontSize, fontweight=1)

    windAx.set_xlim(-0.1, biteWindow + 0.1)
    lastBiteIdx = 0
    finalBiteIdx = len(allBiteTimes[:, 0]) - 1

    # Create the writer only once the plot is ready, and always release it
    # (and the preview window) even if a frame fails to render.
    fourcc = cv2.VideoWriter_fourcc(*codec)
    writer = cv2.VideoWriter(fullMovPath, fourcc, movInfo['FrameRate'],
                             (wf, hf))
    print('Writer size:', wf, hf)

    try:
        for fidx, frameTime in enumerate(movInfo['FrameTimes']):
            windAx.set_xlim(frameTime - 0.1, frameTime + biteWindow + 0.1)

            # Per-frame artists: outline of the current window on both event
            # axes, and a cursor line on the average-duration axes.
            fullRect = Rectangle((frameTime, 0),
                                 biteWindow,
                                 3,
                                 fill=None,
                                 edgecolor='k',
                                 linewidth=2)
            fullAx.add_patch(fullRect)

            windRect = Rectangle((frameTime, 0),
                                 biteWindow,
                                 3,
                                 fill=None,
                                 edgecolor='k',
                                 linewidth=2)
            windAx.add_patch(windRect)

            bdl = '%03d ms' % int(1000 * avgBiteDur[lastBiteIdx])
            avl = avgAx.plot([frameTime, frameTime], [0, 1],
                             color='red',
                             linewidth=2,
                             label=bdl)
            avgAx.set_ylim(avgAxYLim)
            avleg = avgAx.legend(loc='lower right',
                                 fontsize=axLabelFontSize - 2)
            if (allBiteTimes[lastBiteIdx, 0] <= frameTime
                    and lastBiteIdx < finalBiteIdx):
                lastBiteIdx += 1

            canvas.draw()  # draw the canvas, cache the renderer

            # print_to_buffer hands back the RGBA pixels; drop alpha to get
            # the RGB plot image.
            rgba, (width, height) = canvas.print_to_buffer()
            plotIm = np.frombuffer(rgba, dtype=np.uint8).reshape(
                height, width, 4)[:, :, :3]

            # Remove the per-frame artists so they do not pile up.
            avl[0].remove()
            fullRect.remove()
            windRect.remove()
            avleg.remove()

            if fidx == 0:
                plt.figure()
                plt.imshow(plotIm)

            # OpenCV expects BGR, matplotlib and PIL produce RGB.
            plotBgr = plotIm[:, :, ::-1]

            if numVids == 0:
                outputFrame[:, :, :] = plotBgr
            else:
                if hasFront:
                    frontIm = _holdBiteCameraFrame(movInfo['FrontMovie'],
                                                   fidx, frontWindow,
                                                   (wim, him), font,
                                                   timestampCol)
                else:
                    frontIm = Image.new('RGB', (wim, him))
                drawObjF = ImageDraw.Draw(frontIm)

                if hasSide:
                    sideIm = _holdBiteCameraFrame(movInfo['SideMovie'], fidx,
                                                  sideWindow, (wim, him),
                                                  font, timestampCol)
                else:
                    sideIm = Image.new('RGB', (wim, him))

                # draw biting/water
                biting = movInfo['BiteFrames'][fidx]
                deliveringReward = movInfo['WaterFrames'][fidx]
                if biting:
                    drawObjF.text((0, 0), 'BITING', fill=biteCol, font=font)
                if deliveringReward:
                    drawObjF.text((0, fontSize + 1),
                                  'WATER',
                                  fill=waterCol,
                                  font=font)

                frontOut = np.array(frontIm).astype(np.uint8)
                sideOut = np.array(sideIm).astype(np.uint8)
                outputFrame[him:hf, :, :] = plotBgr
                # Flip the camera tiles to BGR too so the coloured labels
                # keep their intended colours in the written movie.
                outputFrame[0:him, 0:wim, :] = frontOut[:, :, ::-1]
                outputFrame[0:him, wim:wf, :] = sideOut[:, :, ::-1]

            cv2.imshow("OutputFrame", outputFrame)
            writer.write(outputFrame)
            # show the frame; stop when ESC (key code 27) is pressed
            key = cv2.waitKey(30)
            if key == 27:
                break
    finally:
        # do a bit of cleanup
        print("[INFO] cleaning up...")
        cv2.destroyAllWindows()
        writer.release()
Ejemplo n.º 16
0
class usb1Windows(QWidget):
    """Qt widget that previews a USB camera and captures stills from it.

    A timer-driven live view is drawn on ``video_figaxes``; captured or
    resized stills are drawn on ``pic_figaxes``, where a rectangle can be
    dragged with the mouse to mark a region that ``save_picture`` also
    writes out as a crop.
    """

    def __del__(self):
        if hasattr(self, "camera"):
            self.camera.release()  # free the capture device

    def init_fun(self):
        """Build the UI, connect all signals and load the face-name list."""
        self.window = Ui_Form()
        self.window.setupUi(self)

        self.timer = QTimer()  # drives the live preview
        self.timer.timeout.connect(self.timer_fun)  # called on every tick

        # 1. open usb and show
        self.window.pushButton_2.clicked.connect(self.timer_start)
        # 2. catch one picture
        self.window.pushButton.clicked.connect(self.catch_picture)

        self.window.comboBox.currentIndexChanged.connect(
            self.set_width_and_height)
        self.window.checkBox.clicked.connect(self.get_faces_flag_fun)

        self.window.pushButton_5.clicked.connect(self.preview_picture)
        self.window.pushButton_4.clicked.connect(self.save_picture)

        self.window.pic_figure.canvas.mpl_connect('button_press_event',
                                                  self.on_press)
        self.window.pic_figure.canvas.mpl_connect('button_release_event',
                                                  self.on_release)

        self.getface_flag = False

        # Names are stored as one ';'-separated line.
        with open("./identiffun/faces.conf", 'r') as fm:
            self.names = fm.read().split(";")
        self.my_get_face = Get_Faces(self.names)

    def on_press(self, event):
        """Remember where the drag started and create the selection patch."""
        self.on_x0 = event.xdata
        self.on_y0 = event.ydata
        if not hasattr(self, "rectload"):
            self.rectload = Rectangle((0, 0),
                                      0,
                                      0,
                                      linestyle='solid',
                                      fill=False,
                                      edgecolor='red')
            self.window.pic_figaxes.add_patch(self.rectload)

    def on_release(self, event):
        """Resize the selection patch to span the press and release points.

        Nothing is changed when either end of the drag lies outside the
        axes (matplotlib then reports ``xdata``/``ydata`` as ``None``) or
        when no selection patch exists.
        """
        self.on_x1 = event.xdata
        self.on_y1 = event.ydata
        corners = (getattr(self, "on_x0", None), getattr(self, "on_y0", None),
                   self.on_x1, self.on_y1)
        if not hasattr(self, "rectload") or any(c is None for c in corners):
            return
        x_start = int(min(self.on_x0, self.on_x1))
        x_end = int(max(self.on_x0, self.on_x1))
        y_start = int(min(self.on_y0, self.on_y1))
        y_end = int(max(self.on_y0, self.on_y1))
        self.rectload.set_xy((x_start, y_start))
        self.rectload.set_height(y_end - y_start + 1)
        self.rectload.set_width(x_end - x_start + 1)
        self.window.pic_figaxes.figure.canvas.draw()

    def save_picture(self):
        """Write the current still, and the selected crop if there is one.

        The resized preview is preferred over the raw capture.  Nothing is
        written when no picture has been captured yet.
        """
        if hasattr(self, 'preview_res'):
            tmp_save_picture = self.preview_res
        elif hasattr(self, 'raw_frame'):
            tmp_save_picture = self.raw_frame
        else:
            return  # no pic
        cv2.imwrite("./image/save.jpg", tmp_save_picture)
        if hasattr(self, "rectload"):
            x, y = self.rectload.get_xy()
            w = self.rectload.get_width()
            h = self.rectload.get_height()
            crop = tmp_save_picture[y:y + h, x:x + w]
            # Skip an empty selection (zero-sized or outside the image)
            # instead of handing an empty array to cv2.imwrite.
            if crop.size:
                cv2.imwrite("./image/ret.jpg", crop)

    def preview_picture(self):
        """Resize the captured frame to the spin-box size and display it."""
        if hasattr(self, 'raw_frame'):
            width = self.window.spinBox.value()
            height = self.window.spinBox_2.value()
            self.preview_res = cv2.resize(self.raw_frame, (width, height),
                                          interpolation=cv2.INTER_CUBIC)
            self.showimg2figaxes2(self.preview_res)

    def get_faces_flag_fun(self):
        """Mirror the check box state into ``getface_flag``."""
        self.getface_flag = self.window.checkBox.isChecked()

    def set_width_and_height(self):
        """Apply the 'W*H' combo-box entry to the camera, if one is open."""
        width, height = self.window.comboBox.currentText().split('*')
        if hasattr(self, "camera"):
            self.camera.set(cv2.CAP_PROP_FRAME_WIDTH, int(width))
            self.camera.set(cv2.CAP_PROP_FRAME_HEIGHT, int(height))

    def catch_picture(self):
        """Grab one frame, keep a private copy and show it as the still."""
        if hasattr(self, "camera") and self.camera.isOpened():
            ret, frame = self.camera.read()
            if ret:
                self.raw_frame = copy.deepcopy(frame)
                # A new capture invalidates any earlier resized preview.
                if hasattr(self, 'preview_res'):
                    del self.preview_res
                self.showimg2figaxes2(frame)

    def timer_fun(self):
        """Timer tick: show the next live frame, or stop when reading fails."""
        ret, frame = self.camera.read()
        if ret:
            self.showimg2figaxes(frame)
        else:
            self.timer.stop()

    def timer_start(self):
        """Open camera 0, reflect its settings in the UI, start the preview."""
        if not hasattr(self, "camera"):
            self.camera = cv2.VideoCapture(0)
        if not self.camera.isOpened():
            self.camera.open(0)

        width = self.camera.get(cv2.CAP_PROP_FRAME_WIDTH)
        print(width)
        height = self.camera.get(cv2.CAP_PROP_FRAME_HEIGHT)
        print(int(height))
        self.window.comboBox.setCurrentText("%d*%d" %
                                            (int(width), int(height)))

        fps = self.camera.get(cv2.CAP_PROP_FPS)
        if fps != float('inf'):
            print(fps)

        # Property read-backs; an infinite reading is shown as 0.0.
        spin_boxes = (
            (cv2.CAP_PROP_BRIGHTNESS, self.window.doubleSpinBox_2),
            (cv2.CAP_PROP_CONTRAST, self.window.doubleSpinBox),
            (cv2.CAP_PROP_HUE, self.window.doubleSpinBox_3),
            (cv2.CAP_PROP_EXPOSURE, self.window.doubleSpinBox_4),
            (cv2.CAP_PROP_SATURATION, self.window.doubleSpinBox_5),
        )
        for prop, box in spin_boxes:
            value = self.camera.get(prop)
            box.setValue(0.0 if value == float('inf') else value)

        self.timer.start(101)  # set the tick interval and start the timer

    def showimg2figaxes2(self, frame):
        """Show a BGR frame on the still-picture axes, dropping any selection."""
        b, g, r = cv2.split(frame)
        imgret = cv2.merge([r, g, b])
        # Clearing the axes would orphan the patch, so remove it explicitly.
        if hasattr(self, "rectload"):
            self.rectload.remove()
            del self.rectload
        self.window.pic_figaxes.clear()
        self.window.pic_figaxes.imshow(imgret)
        self.window.pic_figure.canvas.draw()

    def showimg2figaxes(self, img):
        """Show a BGR frame on the live-video axes, optionally face-annotated."""
        if self.getface_flag:
            tmp_img = self.my_get_face.get_face_fun(img)
        else:
            tmp_img = img
        b, g, r = cv2.split(tmp_img)
        # OpenCV stores BGR while matplotlib displays RGB, so swap channels.
        imgret = cv2.merge([r, g, b])
        self.window.video_figaxes.clear()
        self.window.video_figaxes.imshow(imgret)
        self.window.video_figure.canvas.draw()
Ejemplo n.º 17
0
class WindowSelectionRectangle(object):
    """Full-height rectangle dragged horizontally to select an x-range.

    The rectangle starts at the x position of the triggering mouse event and
    follows the pointer while the left button is held.  On release inside
    ``axis`` the patch is removed, the canvas handlers are disconnected and
    ``on_window_selection_callback(x, width, axis)`` is invoked with a
    non-negative ``width``.

    If the triggering event did not occur inside ``axis`` the object stays
    inert: no patch is created and no handler is connected.
    """

    def __init__(self, event, axis, on_window_selection_callback):
        self.axis = axis
        # Stored before connecting so a handler can never see it missing.
        self.on_window_selection_callback = on_window_selection_callback
        self.intial_selection_active = False
        if event.inaxes != self.axis:
            return
        # Store the axes it has been initialized in.
        self.axes = event.inaxes
        ymin, ymax = self.axes.get_ylim()
        self.min_x = event.xdata
        self.intial_selection_active = True
        self.rect = Rectangle((event.xdata, ymin), 0, ymax - ymin,
                              facecolor="0.3", alpha=0.5, edgecolor="0.5")
        self.axes.add_patch(self.rect)
        # Get the canvas.
        self.canvas = self.rect.figure.canvas

        # Use blitting for fast animations.
        self.rect.set_animated(True)
        self.background = self.canvas.copy_from_bbox(self.rect.axes.bbox)

        self._connect()

    def _connect(self):
        """
        Connect to the necessary events.
        """
        self.conn_button_press = self.canvas.mpl_connect(
            'button_press_event', self.on_button_press)
        self.conn_button_release = self.canvas.mpl_connect(
            'button_release_event', self.on_button_release)
        self.conn_mouse_motion = self.canvas.mpl_connect(
            'motion_notify_event', self.on_mouse_motion)

    def _disconnect(self):
        """
        Disconnect every handler registered by ``_connect``.
        """
        for cid in (self.conn_button_press, self.conn_button_release,
                    self.conn_mouse_motion):
            self.canvas.mpl_disconnect(cid)

    def on_button_press(self, event):
        pass

    def on_button_release(self, event):
        """Finish the selection and report it through the callback."""
        if event.inaxes != self.axis:
            return

        if event.button != 1:
            return
        # turn off the rect animation property and reset the background
        self.rect.set_animated(False)
        self.background = None

        self.intial_selection_active = False
        self.canvas.draw()

        x = self.rect.get_x()
        width = self.rect.get_width()

        # A drag to the left yields a negative width; normalise it.
        if width < 0:
            x = x + width
            width = abs(width)

        self.rect.remove()
        # The selection is over: without this a later release on the same
        # canvas would try to remove the already-removed patch again.
        self._disconnect()
        self.on_window_selection_callback(x, width, self.axis)

    def on_mouse_motion(self, event):
        """Stretch the rectangle to the pointer and blit it."""
        if event.button != 1 or \
                self.intial_selection_active is not True:
            return
        if event.xdata is not None:
            self.rect.set_width(event.xdata - self.min_x)

        # restore the background region
        self.canvas.restore_region(self.background)
        # redraw just the current rectangle
        self.axes.draw_artist(self.rect)
        # blit just the redrawn area
        self.canvas.blit(self.axes.bbox)
Ejemplo n.º 18
0
class RectangleInteractor(QObject):
    """Mouse-editable rectangle (or square) patch with three handle markers.

    Marker 0 sits at the middle of the patch and moves it; marker 1 (middle
    of one side) and marker 2 (middle of the adjacent side) resize the width
    and height respectively and rotate the patch.  When ``height`` is omitted
    the patch is a square and both sides are kept equal.

    Keys: 't' toggles the markers, 'd' emits ``mySignal`` with
    'rectangle deleted'.  ``modSignal`` is emitted with 'rectangle modified'
    on every handled drag step.
    """

    epsilon = 5  # pick tolerance, compared with distances in data coordinates
    showverts = True
    mySignal = pyqtSignal(str)
    modSignal = pyqtSignal(str)

    def __init__(self,ax,corner,width,height=None,angle=0.):
        super().__init__()
        from matplotlib.patches import Rectangle
        from matplotlib.lines import Line2D
        # To avoid crashing with maximum recursion depth exceeded.
        # NOTE: this raises the limit for the whole interpreter.
        import sys
        sys.setrecursionlimit(10000) # 10000 is 10x the default value

        if height is None:
            self.type = 'Square'
            height = width
        else:
            self.type = 'Rectangle'
        self.ax = ax
        self.angle  = angle/180.*np.pi
        self.width  = width
        self.height = height
        self.rect = Rectangle(corner,width,height,edgecolor='Lime',facecolor='none',angle=angle,fill=False,animated=True)
        self.ax.add_patch(self.rect)
        self.canvas = self.rect.figure.canvas
        x,y = self.compute_markers()
        self.line = Line2D(x, y, marker='o', linestyle=None, linewidth=0., markerfacecolor='g', animated=True)
        self.ax.add_line(self.line)
        self.cid = self.rect.add_callback(self.rectangle_changed)
        self._ind = None  # the active point
        self.background = None  # captured by draw_callback on each draw
        self.connect()
        self.aperture = self.rect
        self.press = None
        self.lock = None

    def compute_markers(self):
        """Recompute the marker positions from the current patch geometry.

        Stores the positions as (x, y) pairs in ``self.xy`` and returns them
        as two lists ``(xs, ys)``.
        """
        theta0 = self.rect.angle / 180.*np.pi
        w0 = self.rect.get_width()
        h0 = self.rect.get_height()
        x0,y0 = self.rect.get_xy()
        c, s = np.cos(-theta0), np.sin(-theta0)
        R = np.array([[c, s], [-s, c]])

        # Offsets in the patch's own frame: middle, mid of one side, mid of
        # the adjacent side.
        x = [0.5*w0, w0, 0.5*w0]
        y = [0.5*h0, 0.5*h0, h0]

        self.xy = []
        x_ = []
        y_ = []
        for dx,dy in zip(x,y):
            dx_,dy_ = np.dot(R,np.array([dx,dy]))
            self.xy.append((dx_+x0,dy_+y0))
            x_.append(dx_+x0)
            y_.append(dy_+y0)

        return x_,y_

    def connect(self):
        """Register all canvas handlers and request a redraw."""
        self.cid_draw = self.canvas.mpl_connect('draw_event', self.draw_callback)
        self.cid_press = self.canvas.mpl_connect('button_press_event', self.button_press_callback)
        self.cid_release = self.canvas.mpl_connect('button_release_event', self.button_release_callback)
        self.cid_motion = self.canvas.mpl_connect('motion_notify_event', self.motion_notify_callback)
        self.cid_key = self.canvas.mpl_connect('key_press_event', self.key_press_callback)
        self.canvas.draw_idle()

    def disconnect(self):
        """Unregister all handlers and remove the patch and markers."""
        self.canvas.mpl_disconnect(self.cid_draw)
        self.canvas.mpl_disconnect(self.cid_press)
        self.canvas.mpl_disconnect(self.cid_release)
        self.canvas.mpl_disconnect(self.cid_motion)
        self.canvas.mpl_disconnect(self.cid_key)
        self.rect.remove()
        self.line.remove()
        self.canvas.draw_idle()
        self.aperture = None

    def draw_callback(self, event):
        """Capture the background for blitting, then draw the animated artists."""
        self.background = self.canvas.copy_from_bbox(self.ax.bbox)
        self.ax.draw_artist(self.rect)
        self.ax.draw_artist(self.line)

    def rectangle_changed(self, rect):
        'called through the callback registered on the rectangle patch'
        # Imported here because this block's only import of Artist is
        # commented out in __init__.
        from matplotlib.artist import Artist
        # only copy the artist props to the line (except visibility)
        vis = self.line.get_visible()
        Artist.update_from(self.line, rect)
        self.line.set_visible(vis)

    def get_ind_under_point(self, event):
        'get the index of the point if within epsilon tolerance'

        x, y = zip(*self.xy)
        # Convert explicitly: the tuples from zip do not support arithmetic.
        d = np.hypot(np.array(x) - event.xdata, np.array(y) - event.ydata)
        indseq, = np.nonzero(d == d.min())
        ind = indseq[0]

        if d[ind] >= self.epsilon:
            ind = None

        return ind

    def button_press_callback(self, event):
        'whenever a mouse button is pressed'
        if not self.showverts:
            return
        if event.inaxes is None:
            return
        if event.button != 1:
            return
        self._ind = self.get_ind_under_point(event)
        # Snapshot the geometry at press time; drags are applied relative
        # to it.
        x0, y0 = self.rect.get_xy()
        w0, h0 = self.rect.get_width(), self.rect.get_height()
        theta0 = self.rect.angle/180*np.pi
        self.press = x0, y0, w0, h0, theta0, event.xdata, event.ydata
        self.xy0 = self.xy

        self.lock = "pressed"

    def key_press_callback(self, event):
        'whenever a key is pressed'
        if not event.inaxes:
            return
        if event.key == 't':
            self.showverts = not self.showverts
            self.line.set_visible(self.showverts)
            if not self.showverts:
                self._ind = None
        elif event.key == 'd':
            # Only the signal is emitted here; this method does not remove
            # the patch itself.
            self.mySignal.emit('rectangle deleted')
        self.canvas.draw_idle()

    def button_release_callback(self, event):
        'whenever a mouse button is released'
        if not self.showverts:
            return
        if event.button != 1:
            return
        self._ind = None
        self.press = None
        self.lock = "released"
        self.background = None
        # To get other aperture redrawn
        self.canvas.draw_idle()

    def motion_notify_callback(self, event):
        'on mouse movement'
        if not self.showverts:
            return
        if self._ind is None:
            return
        if event.inaxes is None:
            return
        if event.button != 1:
            return
        x0, y0, w0, h0, theta0, xpress, ypress = self.press
        self.dx = event.xdata - xpress
        self.dy = event.ydata - ypress
        self.update_rectangle()

        if self.background is None:
            # No background captured yet (no draw since the last release):
            # fall back to a full redraw instead of restoring None.
            self.canvas.draw_idle()
        else:
            # Redraw rectangle and points
            self.canvas.restore_region(self.background)
            self.ax.draw_artist(self.rect)
            self.ax.draw_artist(self.line)
            self.canvas.update()
            self.canvas.flush_events()

        # Notify callback
        self.modSignal.emit('rectangle modified')

    def update_rectangle(self):
        """Apply the current drag offset (``self.dx``, ``self.dy``) to the patch.

        The first motion event after a press only decides the mode from the
        grabbed marker: marker 0 moves, any other marker resizes and rotates.
        """
        x0, y0, w0, h0, theta0, xpress, ypress = self.press
        dx, dy = self.dx, self.dy

        if self.lock == "pressed":
            if self._ind == 0:
                self.lock = "move"
            else:
                self.lock = "resizerotate"
        elif self.lock == "move":
            # A move that would push the anchor below zero on an axis is
            # ignored on that axis.
            if x0+dx < 0:
                xn = x0
                dx = 0
            else:
                xn = x0+dx
            if y0+dy < 0:
                yn = y0
                dy = 0
            else:
                yn = y0+dy
            self.rect.set_xy((xn,yn))
            # update line
            self.xy = [(i+dx,j+dy) for (i,j) in self.xy0]
            # Redefine line
            self.line.set_data(zip(*self.xy))
        # otherwise rotate and resize
        elif self.lock == 'resizerotate':
            xc,yc = self.xy0[0] # center is conserved in the markers
            dtheta = np.arctan2(ypress+dy-yc,xpress+dx-xc)-np.arctan2(ypress-yc,xpress-xc)
            theta_ = (theta0+dtheta) * 180./np.pi
            # Express the drag in the patch's own (rotated) frame.
            c, s = np.cos(theta0), np.sin(theta0)
            R = np.array([[c, s], [-s, c]])
            dx_,dy_ = np.dot(R,np.array([dx,dy]))

            # Avoid to pass through the center
            if self._ind == 1:
                w_ = w0+2*dx_  if (w0+2*dx_) > 0 else w0
                if self.type == 'Square':
                    h_ = w_
                else:
                    h_ = h0
            elif self._ind == 2:
                h_ = h0+2*dy_  if (h0+2*dy_) > 0 else h0
                if self.type == 'Square':
                    w_ = h_
                else:
                    w_ = w0
            # update rectangle
            self.rect.set_width(w_)
            self.rect.set_height(h_)
            self.rect.angle = theta_
            # update markers
            self.updateMarkers()

    def updateMarkers(self):
        """Move the handle markers to match the current patch geometry."""
        x,y = self.compute_markers()
        self.line.set_data(x,y)
Ejemplo n.º 19
0
class MplWidget(QtWidgets.QWidget):
    def __init__(self, parent=None):
        """Create the figure canvas, toolbar, layout and annotation state.

        ``parent`` is forwarded unchanged to the QWidget constructor.
        """
        QWidget.__init__(self, parent)

        # The scroll area is created with this widget as parent and then
        # detached again.
        self.scroll = QtWidgets.QScrollArea(self)
        self.scroll.setParent(None)

        # One axes covering the whole figure: [left, bottom, width, height].
        self.fig = Figure()
        self.fig.add_axes([0.0, 0.0, 1, 1])
        self.canvas = FigureCanvas(self.fig)
        self.fig.set_facecolor([0.23, 0.23, 0.23, 0.5])
        self.canvas.axes = self.canvas.figure.gca()

        # Canvas on top, toolbar underneath.
        self.vertical_layout = QVBoxLayout()
        self.vertical_layout.addWidget(self.canvas)

        self.mpl_toolbar = my_toolbar(self.canvas, self)
        toolbar = self.mpl_toolbar
        toolbar.setParentClass(self)
        toolbar.setMinimumWidth(100)
        toolbar.setFixedHeight(26)
        toolbar.setStyleSheet(
            "QToolBar { opacity: 1;border: 0px; background-color: rgb(133, 196, 65); border-bottom: 1px solid #19232D;padding: 2px;  font-weight: bold;spacing: 2px; } "
        )
        toolbar.setObjectName("myToolBar")

        self.vertical_layout.addWidget(toolbar)
        self.setLayout(self.vertical_layout)
        self.layout().setContentsMargins(0, 0, 0, 0)
        self.layout().setSpacing(0)

        # Interaction state.
        self.rect = Rectangle((0, 0), 1, 1)
        self.updateSecondImage = None
        self.patchesTotal = 0
        self.typeOfAnnotation = "autoDetcted"
        self.frameAtString = "Frame 0"
        self.currentSelectedOption = None

        # One list of boxes per annotation category.
        categories = ("eraseBox", "oneWormLive", "multiWormLive",
                      "oneWormDead", "multiWormDead", "miscBoxes",
                      "autoDetcted")
        self.AllBoxListDictionary = {category: [] for category in categories}

        # Per-category attributes alias the very same list objects held in
        # the dictionary.
        boxes = self.AllBoxListDictionary
        self.eraseBoxXYValues = boxes["eraseBox"]
        self.addBoxXYValues = boxes["miscBoxes"]
        self.oneWormLiveBoxXYValues = boxes["oneWormLive"]
        self.multiWormLiveBoxXYValues = boxes["multiWormLive"]
        self.oneWormDeadBoxXYValues = boxes["oneWormDead"]
        self.multiWormDeadBoxXYValues = boxes["multiWormDead"]
        self.autoDetectedBoxXYValues = boxes["autoDetcted"]
        self.tempList = []

    def resetAllBoxListDictionary(self):
        """Replace the box dictionary with one holding a fresh empty list per category.

        Only the dictionary is rebound; the per-category ``*BoxXYValues``
        attributes are not touched here.
        """
        categories = ("eraseBox", "oneWormLive", "multiWormLive",
                      "oneWormDead", "multiWormDead", "miscBoxes",
                      "autoDetcted")
        self.AllBoxListDictionary = {category: [] for category in categories}

    def updateAllBoxListDictionary(self):
        """Copy the per-category list attributes into ``AllBoxListDictionary``."""
        key_to_attribute = (
            ("eraseBox", "eraseBoxXYValues"),
            ("miscBoxes", "addBoxXYValues"),
            ("oneWormLive", "oneWormLiveBoxXYValues"),
            ("multiWormLive", "multiWormLiveBoxXYValues"),
            ("oneWormDead", "oneWormDeadBoxXYValues"),
            ("multiWormDead", "multiWormDeadBoxXYValues"),
            ("autoDetcted", "autoDetectedBoxXYValues"),
        )
        for key, attribute in key_to_attribute:
            self.AllBoxListDictionary[key] = getattr(self, attribute)

    def updateAllListFromAllBoxListDictionary(self):
        """Rebind the per-category list attributes to the lists in ``AllBoxListDictionary``."""
        attribute_to_key = (
            ("eraseBoxXYValues", "eraseBox"),
            ("addBoxXYValues", "miscBoxes"),
            ("oneWormLiveBoxXYValues", "oneWormLive"),
            ("multiWormLiveBoxXYValues", "multiWormLive"),
            ("oneWormDeadBoxXYValues", "oneWormDead"),
            ("multiWormDeadBoxXYValues", "multiWormDead"),
            ("autoDetectedBoxXYValues", "autoDetcted"),
        )
        for attribute, key in attribute_to_key:
            setattr(self, attribute, self.AllBoxListDictionary[key])

    def setFrameAtString(self, text):
        """Remember ``text`` as the frame label string."""
        self.frameAtString = text

    def getFrameAtString(self):
        """Return the frame label string last stored on this widget."""
        label = self.frameAtString
        return label

    def getCurrentSelectedOption(self):
        """Return the currently selected classification option (``None`` until one is set)."""
        option = self.currentSelectedOption
        return option

    def setCurrentSelectedOption(self, option):
        """Store ``option`` as the classification applied by ``classifyAsCurrentSelection``."""
        self.currentSelectedOption = option

    def setDarkTheme(self):
        """Apply the dark theme to the toolbar and figure, then redraw.

        The toolbar stylesheet previously contained ``rgb(133, 0,s 65)``; the
        stray ``s`` made the colour (and therefore the rule) invalid, so it
        has been removed.
        """
        self.mpl_toolbar.setStyleSheet(
            "QToolBar#myToolBar{ border: 0px; background-color: rgb(133, 0, 65); border-bottom: 1px solid #19232D;padding: 2px;  font-weight: bold;spacing: 2px; } "
        )
        # Semi-transparent dark grey figure background (RGBA).
        self.fig.set_facecolor([0.23, 0.23, 0.23, 0.5])
        self.canvas.draw()

    def setGreenTheme(self):
        """Apply the green toolbar style and a grey figure background, then redraw."""
        green_toolbar_style = "QToolBar { border: 0px; background-color: rgb(133, 196, 65); border-bottom: 1px solid #19232D;padding: 2px;  font-weight: bold;spacing: 2px; } "
        self.mpl_toolbar.setStyleSheet(green_toolbar_style)
        self.fig.set_facecolor('grey')
        self.canvas.draw()

    def setTypeOfAnnotation(self, text):
        """Select the annotation category used for boxes drawn from now on."""
        self.typeOfAnnotation = text

    def restrictCanvasMinimumSize(self, size):
        """Forward ``size`` to the canvas' ``setMinimumSize``."""
        canvas = self.canvas
        canvas.setMinimumSize(size)

    def unmountWidgetAndClear(self):
        """Detach and discard the current canvas and scroll area, then create fresh ones.

        The replacement canvas wraps a brand-new ``Figure`` and is not added
        to ``vertical_layout`` here.
        NOTE(review): ``self.fig`` and ``self.mpl_toolbar`` are not updated in
        this method, so they presumably still refer to the old figure/canvas;
        confirm against the callers.
        """
        # Take both widgets out of the layout and un-parent them.
        self.vertical_layout.removeWidget(self.canvas)
        self.vertical_layout.removeWidget(self.scroll)
        self.scroll.setParent(None)
        self.canvas.setParent(None)
        # Destroy the underlying scroll-area object explicitly.
        sip.delete(self.scroll)
        # Drop every reference to the old widgets before building new ones.
        del self.canvas
        self.scroll = None
        self.canvas = None
        # Fresh canvas on a new empty figure, with its axes exposed as canvas.axes.
        self.canvas = FigureCanvas(Figure())
        self.canvas.axes = self.canvas.figure.gca()
        #self.canvas.figure.tight_layout()
        # Fresh scroll area parented to this widget.
        self.scroll = QtWidgets.QScrollArea(self)
        self.scroll.setWidgetResizable(True)

    def connectClickListnerToCurrentImageForCrop(self,
                                                 givenController,
                                                 updateSecondImage=None,
                                                 listOfControllers=None,
                                                 keyForController=None):
        """Hook the press/move/release crop handlers onto the canvas.

        The connection ids are kept in ``cid1``..``cid3`` so they can be
        disconnected later.  ``givenController`` receives the crop
        coordinates on release; ``updateSecondImage``, ``listOfControllers``
        and ``keyForController`` are stored for use by
        ``on_release_for_crop``.
        """
        connect = self.canvas.mpl_connect
        self.cid1 = connect("button_press_event", self.on_press_for_crop)
        self.cid2 = connect("motion_notify_event", self.onmove_for_crop)
        self.cid3 = connect("button_release_event", self.on_release_for_crop)

        self.givenControllerObject = givenController
        self.updateSecondImage = updateSecondImage
        # No drag is in progress yet.
        self.pressevent = None
        self.listOfControllers = listOfControllers
        self.keyForController = keyForController

    def on_press_for_crop(self, event):
        """Start a crop selection at the pressed data coordinates.

        Does nothing while a toolbar mode is active.  The previous selection
        rectangle is removed (best effort); if the press carries data
        coordinates, a new semi-transparent 1x1 rectangle is added to the axes
        and ``self.pressevent`` is set so ``onmove_for_crop`` starts resizing it.
        """
        if self.mpl_toolbar.mode:
            return

        # Best effort: the previous rectangle may never have been attached to
        # the axes, or may already have been removed.
        try:
            self.rect.remove()
        except (NotImplementedError, ValueError, AttributeError):
            pass
        self.addedPatch = None

        # Matplotlib reports xdata/ydata as None when the pointer is outside
        # the axes; a rectangle cannot be anchored there.
        if event.xdata is None or event.ydata is None:
            return

        self.x0 = event.xdata
        self.y0 = event.ydata
        self.rect = Rectangle((self.x0, self.y0), 1, 1)
        # Alpha must be set before the colour so the colour picks it up.
        self.rect.set_alpha(0.5)
        self.rect.set_linewidth(2)
        self.rect.set_color("C2")
        self.pressevent = 1
        self.addedPatch = self.canvas.axes.add_patch(self.rect)

    def on_release_for_crop(self, event):
        """Finish a crop drag: send the box to the controller and show the result.

        Does nothing while a toolbar mode is active.  If no complete drag was
        recorded (a corner is missing or ``None``), the drag state is cleared
        and nothing is cropped.  Otherwise the corners are passed to the
        controller as ``[min_x, min_y, max_x, max_y]`` integers, this canvas
        shows the controller's manual-crop image and, when a second widget
        was supplied, that widget shows cropped image 0.
        """
        if self.mpl_toolbar.mode:
            return

        self.pressevent = None

        # A click without any mouse movement never sets x1/y1, and a coordinate
        # is None when it was recorded outside the axes.
        corners = [getattr(self, name, None) for name in ("x0", "y0", "x1", "y1")]
        if any(value is None for value in corners):
            return
        x0, y0, x1, y1 = corners

        # Lower bounds are rounded up, upper bounds to the nearest integer.
        minMaxVertices = [
            int(np.ceil(min(x0, x1))),
            int(np.ceil(min(y0, y1))),
            int(np.round(max(x0, x1))),
            int(np.round(max(y0, y1))),
        ]
        self.givenControllerObject.updateManualCropCoordinates(minMaxVertices)
        image = self.givenControllerObject.showManualCropImage()
        self.canvas.axes.clear()
        self.canvas.axes.axis("off")
        self.canvas.axes.imshow(image)
        self.canvas.draw()

        if self.updateSecondImage is not None:
            second_canvas = self.updateSecondImage.canvas
            second_canvas.axes.clear()
            second_canvas.axes.axis("off")
            second_canvas.axes.imshow(
                self.givenControllerObject.getCroppedImage(0))
            second_canvas.draw()
            # The controller mapping is optional; skip the write-back when
            # none was supplied at connect time.
            if self.listOfControllers is not None:
                self.listOfControllers[
                    self.keyForController] = self.givenControllerObject

    def onmove_for_crop(self, event):
        """Stretch the crop rectangle to the pointer while a drag is in progress.

        Motion events without data coordinates (pointer outside the axes) are
        ignored so the last valid corner is kept.
        """
        if self.pressevent is None:
            return
        # xdata/ydata are None outside the axes; subtracting would raise.
        if event.xdata is None or event.ydata is None:
            return

        self.x1 = event.xdata
        self.y1 = event.ydata
        self.rect.set_width(self.x1 - self.x0)
        self.rect.set_height(self.y1 - self.y0)
        self.rect.set_xy((self.x0, self.y0))
        self.canvas.draw()

    def disconnectClickListnerFromCurrentImageForCrop(self):
        """Disconnect the crop handlers and forget the second image widget.

        Safe to call when the handlers were never connected: missing
        connection ids are skipped instead of aborting the whole clean-up.
        """
        for cid_name in ("cid1", "cid2", "cid3"):
            cid = getattr(self, cid_name, None)
            if cid is not None:
                self.canvas.mpl_disconnect(cid)
        self.updateSecondImage = None

    def getCurrentScrollParam(self):
        """Record the scroll area's current vertical and horizontal scrollbar values."""
        vertical_bar = self.scroll.verticalScrollBar()
        self.currentVerticalSliderValue = vertical_bar.value()
        horizontal_bar = self.scroll.horizontalScrollBar()
        self.currentHorizontalSliderValue = horizontal_bar.value()

    def resetCurrentScrollParam(self):
        """Restore the scrollbar values recorded by ``getCurrentScrollParam``."""
        vertical_bar = self.scroll.verticalScrollBar()
        vertical_bar.setValue(self.currentVerticalSliderValue)
        horizontal_bar = self.scroll.horizontalScrollBar()
        horizontal_bar.setValue(self.currentHorizontalSliderValue)

    def resize(self, event):
        """Reposition the navigation toolbar at (0, 0) of the first axes.

        ``event`` is not used; the position is derived from the figure's
        current size and dpi.
        """
        # Display-space position of the axes' lower-left corner.
        x, y = self.fig.axes[0].transAxes.transform((0, 0))
        _, fig_height = self.fig.get_size_inches()
        # The axes transform measures y from the bottom, whereas move() is
        # given a y measured from the top.
        ynew = (fig_height * self.fig.dpi - y
                - self.mpl_toolbar.frameGeometry().height())
        # QWidget.move takes integer pixels; the transform yields floats, which
        # are no longer converted implicitly on recent Python versions.
        self.mpl_toolbar.move(int(x), int(ynew))

    def connectClickListnerToCurrentImageForAnnotate(self,
                                                     givenController,
                                                     updateSecondImage=None):
        """Hook the annotation press handler and the pick handler onto the canvas.

        The connection ids are kept in ``cid4`` (button press) and ``cid7``
        (pick).  ``givenController`` and ``updateSecondImage`` are stored on
        the widget.
        """
        connect = self.canvas.mpl_connect
        self.cid4 = connect("button_press_event", self.on_press_for_annotate)
        self.cid7 = connect('pick_event', self.onpick)

        self.givenControllerObject = givenController
        self.updateSecondImage = updateSecondImage
        # No drag is in progress yet.
        self.pressevent = None

    def autoAnnotateOnOverlay(self, autoDetectedObjects):
        """Overlay one pickable dashed green rectangle per auto-detected object.

        ``autoDetectedObjects`` is iterated with ``iterrows()``; every row must
        expose ``bbox0``..``bbox3``.  The rectangle is anchored at
        ``(bbox1, bbox0)`` and spans to ``(bbox3, bbox2)``.
        NOTE(review): presumably (min_row, min_col, max_row, max_col) boxes;
        confirm with the producer of the table.

        Each box is appended to ``autoDetectedBoxXYValues`` as
        ``[bbox1, bbox0, bbox3, bbox2]`` unless already present.

        Every rectangle is now created at its final size.  Previously each
        iteration resized the rectangle left over from the step before and
        added a new 1x1 one, which clobbered an unrelated earlier rectangle
        and left the last box at 1x1.  The canvas is redrawn once at the end
        instead of once per row.
        """
        added_any = False
        for _, row in autoDetectedObjects.iterrows():
            box = [row.bbox1, row.bbox0, row.bbox3, row.bbox2]

            self.rect = Rectangle((row.bbox1, row.bbox0),
                                  row.bbox3 - row.bbox1,
                                  row.bbox2 - row.bbox0,
                                  picker=True)
            # Private attributes are written directly, as elsewhere in this
            # class, so the fully transparent face colour is stored verbatim.
            self.rect._alpha = 1
            self.rect._edgecolor = (0, 1, 0, 1)
            self.rect._facecolor = (0, 0, 0, 0)
            self.rect._linewidth = 1
            self.rect.set_linestyle('dashed')
            self.rect.addName = self.typeOfAnnotation
            # Kept from the original implementation for compatibility.
            self.pressevent = 1
            self.canvas.axes.add_patch(self.rect)
            self.patchesTotal = self.patchesTotal + 1

            if box not in self.autoDetectedBoxXYValues:
                self.autoDetectedBoxXYValues.append(box)

            # Update latest values
            self.updateAllBoxListDictionary()
            added_any = True

        if added_any:
            self.canvas.draw()

    def on_press_for_annotate(self, event):
        """Start drawing an annotation box on a left-button press.

        Ignored while a toolbar mode is active, for other mouse buttons, and
        when the press carries no data coordinates.  Otherwise the
        move/release handlers are connected (``cid5``/``cid6``) and a dashed,
        unfilled 1x1 rectangle in the colour of the current annotation type is
        added to the axes.
        """
        if self.mpl_toolbar.mode:
            return
        if event.button != 1:
            return
        # xdata/ydata are None outside the axes; without this guard the
        # handlers would be connected for a rectangle that cannot be placed.
        if event.xdata is None or event.ydata is None:
            return

        self.cid5 = self.canvas.mpl_connect("motion_notify_event",
                                            self.onmove_for_annotate)
        self.cid6 = self.canvas.mpl_connect("button_release_event",
                                            self.on_release_for_annotate)

        self.x0 = event.xdata
        self.y0 = event.ydata

        # Edge colour per annotation type; anything else (including
        # "autoDetcted") is drawn green.
        edge_colours = {
            "eraseBox": (0, 0, 0, 1),
            "oneWormLive": (0, 0, 1, 1),
            "multiWormLive": (1, 1, 0, 1),
            "oneWormDead": (1, 0, 0, 1),
            "multiWormDead": (1, 1, 1, 1),
        }

        self.rect = Rectangle((self.x0, self.y0), 1, 1, picker=True)
        self.rect._alpha = 1
        self.rect._edgecolor = edge_colours.get(self.typeOfAnnotation,
                                                (0, 1, 0, 1))
        self.rect._facecolor = (0, 0, 0, 0)
        self.rect._linewidth = 1
        self.rect.set_linestyle('dashed')
        self.rect.addName = self.typeOfAnnotation
        self.pressevent = 1
        self.canvas.axes.add_patch(self.rect)
        self.patchesTotal = self.patchesTotal + 1

    def on_release_for_annotate(self, event):
        """Finish an annotation drag and record the drawn box.

        Does nothing while a toolbar mode is active.  On a left-button release
        the move/release handlers are disconnected and an untouched 1x1
        placeholder (click without drag) is removed and taken out of
        ``patchesTotal``.  The last ``[x0, y0, x1, y1]`` collected in
        ``tempList`` is appended to the list of the current annotation type
        (unknown types go to ``addBoxXYValues``).  When ``tempList`` is empty
        nothing is recorded; previously this raised ``IndexError``.
        """
        if self.mpl_toolbar.mode:
            return

        if event.button == 1:
            self.canvas.mpl_disconnect(self.cid5)
            if (self.rect.get_height() == 1) and (self.rect.get_width() == 1):
                # The placeholder was never resized: discard it and keep the
                # patch counter in step with what is on the axes.
                self.rect.remove()
                self.patchesTotal = self.patchesTotal - 1
            self.pressevent = None
            self.canvas.mpl_disconnect(self.cid6)

        if self.tempList:
            targets = {
                "eraseBox": self.eraseBoxXYValues,
                "oneWormLive": self.oneWormLiveBoxXYValues,
                "multiWormLive": self.multiWormLiveBoxXYValues,
                "oneWormDead": self.oneWormDeadBoxXYValues,
                "multiWormDead": self.multiWormDeadBoxXYValues,
            }
            target = targets.get(self.typeOfAnnotation, self.addBoxXYValues)
            target.append(self.tempList[-1])
            self.tempList = []

        self.updateAllBoxListDictionary()

    def onmove_for_annotate(self, event):
        """Stretch the annotation rectangle to the pointer and log the box.

        Each distinct ``[x0, y0, x1, y1]`` seen during the drag is appended to
        ``tempList``; the release handler uses the last entry.  Motion events
        without data coordinates (pointer outside the axes) are ignored.
        """
        if self.pressevent is None:
            return
        # xdata/ydata are None outside the axes; subtracting would raise.
        if event.xdata is None or event.ydata is None:
            return

        self.x1 = event.xdata
        self.y1 = event.ydata
        self.rect.set_width(self.x1 - self.x0)
        self.rect.set_height(self.y1 - self.y0)
        self.rect.set_xy((self.x0, self.y0))

        # The original per-type branches all did exactly this and together
        # covered every possible annotation type.
        current_box = [self.x0, self.y0, self.x1, self.y1]
        if current_box not in self.tempList:
            self.tempList.append(current_box)

        self.canvas.draw()

    def getEraseBoxXYValues(self):
        """Return the list of boxes recorded for the "eraseBox" category."""
        return self.eraseBoxXYValues

    def getAutoDetctedBoxXYValues(self):
        """Return the list of auto-detected boxes."""
        return self.autoDetectedBoxXYValues

    def getAddBoxXYValues(self):
        """Return the list of boxes recorded for the "miscBoxes" category."""
        return self.addBoxXYValues

    def getOneWormLiveBoxXYValues(self):
        """Return the list of boxes recorded for the "oneWormLive" category."""
        return self.oneWormLiveBoxXYValues

    def getMultiWormLiveBoxXYValues(self):
        """Return the list of boxes recorded for the "multiWormLive" category."""
        return self.multiWormLiveBoxXYValues

    def getOneWormDeadBoxXYValues(self):
        """Return the list of boxes recorded for the "oneWormDead" category."""
        return self.oneWormDeadBoxXYValues

    def getMultiWormDeadBoxXYValues(self):
        """Return the list of boxes recorded for the "multiWormDead" category."""
        return self.multiWormDeadBoxXYValues

    def resetEraseBoxXYValues(self):
        """Rebind the erase-box attribute to a new empty list.

        ``AllBoxListDictionary`` is not modified here.
        """
        self.eraseBoxXYValues = list()

    def resetAutoDetctedBoxXYValues(self):
        """Rebind the auto-detected-box attribute to a new empty list.

        ``AllBoxListDictionary`` is not modified here.
        """
        self.autoDetectedBoxXYValues = list()

    def resetAddBoxXYValues(self):
        """Rebind the misc-box attribute to a new empty list.

        ``AllBoxListDictionary`` is not modified here.
        """
        self.addBoxXYValues = list()

    def resetOneWormLiveBoxXYValues(self):
        """Rebind the one-worm-live attribute to a new empty list.

        ``AllBoxListDictionary`` is not modified here.
        """
        self.oneWormLiveBoxXYValues = list()

    def resetMultiWormLiveBoxXYValues(self):
        """Rebind the multi-worm-live attribute to a new empty list.

        ``AllBoxListDictionary`` is not modified here.
        """
        self.multiWormLiveBoxXYValues = list()

    def resetOneWormDeadBoxXYValues(self):
        """Rebind the one-worm-dead attribute to a new empty list.

        ``AllBoxListDictionary`` is not modified here.
        """
        self.oneWormDeadBoxXYValues = list()

    def resetMultiWormDeadBoxXYValues(self):
        """Rebind the multi-worm-dead attribute to a new empty list.

        ``AllBoxListDictionary`` is not modified here.
        """
        self.multiWormDeadBoxXYValues = list()

    def disconnectClickListnerFromCurrentImageForAnnotate(self):
        """Disconnect the annotation press/pick handlers and forget the second image widget.

        Safe to call when the handlers were never connected: missing
        connection ids are skipped instead of aborting the whole clean-up.
        """
        for cid_name in ("cid4", "cid7"):
            cid = getattr(self, cid_name, None)
            if cid is not None:
                self.canvas.mpl_disconnect(cid)
        self.updateSecondImage = None

    def onpick(self, event):
        """Open the box context menu when a pickable artist is right-clicked.

        Ignored while a toolbar mode is active and for any mouse button other
        than 3 (the right button).  The picked artist is remembered in
        ``self.objectPicked`` for the two menu actions.
        """
        if self.mpl_toolbar.mode:
            return
        if event.mouseevent.button != 3:
            return

        self.objectPicked = event.artist

        # Menu entries
        delete_action = QtWidgets.QAction('Delete Box', self)
        classify_action = QtWidgets.QAction('Classify', self)

        # Context menu holding both entries
        self.popMenu = QtWidgets.QMenu(self)
        self.popMenu.addAction(delete_action)
        self.popMenu.addAction(classify_action)

        cursor = QtGui.QCursor()
        delete_action.triggered.connect(lambda: self.removeThisArea(1))
        classify_action.triggered.connect(
            lambda: self.classifyAsCurrentSelection(1))

        # Show the menu at the current cursor position.
        self.popMenu.popup(cursor.pos())

    def right_click_press_for_annotate(self, event):
        """Open a Remove/Add context menu on a right-button press.

        Ignored while a toolbar mode is active and for buttons other than 3.
        NOTE(review): the value the ``triggered`` signal supplies is forwarded
        as ``caseNumber`` to the handlers, which only act when it equals 1;
        confirm this is intended.
        """
        if self.mpl_toolbar.mode:
            return
        if event.button != 3:  # "3" is the right button
            return

        # Menu entries
        remove_action = QtWidgets.QAction('Remove', self)
        add_action = QtWidgets.QAction('Add', self)

        # Context menu holding both entries
        self.popMenu = QtWidgets.QMenu(self)
        self.popMenu.addAction(remove_action)
        self.popMenu.addAction(add_action)
        cursor = QtGui.QCursor()

        remove_action.triggered.connect(
            lambda eventData=object: self.removeThisArea(eventData))
        add_action.triggered.connect(
            lambda eventData=object: self.classifyAsCurrentSelection(
                eventData))
        self.popMenu.popup(cursor.pos())

    def classifyAsCurrentSelection(self, caseNumber):
        """Reclassify the picked auto-detected box as the currently selected option.

        Only ``caseNumber == 1`` does any work.  The picked artist's geometry
        gives ``[X0, Y0, X1, Y1]``; the box is moved from
        ``autoDetectedBoxXYValues`` to the list of ``currentSelectedOption``
        and a rectangle in that option's colour is overlaid on it:

        * "autoDetcted": prints "Already Selected!" and changes nothing.
        * one of the four worm classes: moved unless already in that list.
        * anything else (including "eraseBox", which also prints a hint, and
          ``None``): moved to ``addBoxXYValues`` and drawn green.
          NOTE(review): moving the box for "eraseBox" is kept from the
          original; the printed hint suggests a no-op may have been intended.

        If nothing has been picked, or the box is not in the auto-detected
        list, "Delete and Redraw!" is printed.

        The overlay is now created at the picked box's full size.  Previously
        the rectangle drawn last was resized onto the picked box and a new
        1x1 rectangle was added, so colours lagged one classification behind.
        The picked artist itself is left on the axes, as before.
        """
        # Get all the list values for this frame
        self.updateAllListFromAllBoxListDictionary()

        if caseNumber == 1:
            try:
                picked = self.objectPicked  # AttributeError if nothing picked
                X0 = picked.get_xy()[0]
                Y0 = picked.get_xy()[1]
                width = picked.get_width()
                height = picked.get_height()
                selectedBoxCoords = [X0, Y0, X0 + width, Y0 + height]
                option = self.currentSelectedOption

                if option == "eraseBox":
                    print("Use Delte Option! Right Click -> Delete Box")

                if option == "autoDetcted":
                    print("Already Selected!")
                else:
                    # option -> (target list, edge colour)
                    classes = {
                        "oneWormLive": (self.oneWormLiveBoxXYValues,
                                        (0, 0, 1, 1)),
                        "multiWormLive": (self.multiWormLiveBoxXYValues,
                                          (1, 1, 0, 1)),
                        "oneWormDead": (self.oneWormDeadBoxXYValues,
                                        (1, 0, 0, 1)),
                        "multiWormDead": (self.multiWormDeadBoxXYValues,
                                          (1, 1, 1, 1)),
                    }
                    target, colour = classes.get(
                        option, (self.addBoxXYValues, (0, 1, 0, 1)))

                    # A worm class is skipped when the box is already in it;
                    # the misc fallback has no such check.
                    if option not in classes or selectedBoxCoords not in target:
                        # ValueError when the box is not an auto-detected one.
                        self.autoDetectedBoxXYValues.remove(selectedBoxCoords)
                        target.append(selectedBoxCoords)

                        # Reuse the picked width/height so that X0 + width
                        # reproduces the stored X1 exactly.
                        self.rect = Rectangle((X0, Y0), width, height,
                                              picker=True)
                        self.rect._alpha = 1
                        self.rect._edgecolor = colour
                        self.rect._facecolor = (0, 0, 0, 0)
                        self.rect._linewidth = 1
                        self.rect.set_linestyle('dashed')
                        self.rect.addName = self.typeOfAnnotation
                        # Kept from the original implementation.
                        self.pressevent = 1
                        self.canvas.axes.add_patch(self.rect)
                        self.canvas.draw()
            except (AttributeError, ValueError):
                print("Delete and Redraw!")

        self.updateAllBoxListDictionary()

    def removeThisArea(self, caseNumber):
        """Delete the picked box from the axes and from every category list.

        Only ``caseNumber == 1`` removes anything: the picked artist is taken
        off the axes, ``patchesTotal`` is decremented and its
        ``[X0, Y0, X1, Y1]`` is removed from each list that contains it.  The
        canvas is redrawn in every case.

        The former bare ``except: pass`` around the list removals was dropped:
        each removal is guarded by a membership test and cannot raise, so the
        handler could only hide unrelated errors.  Debug prints were removed.
        """
        # Get all the list values for this frame
        self.updateAllListFromAllBoxListDictionary()

        if caseNumber == 1:
            picked = self.objectPicked
            X0 = picked.get_xy()[0]
            Y0 = picked.get_xy()[1]
            X1 = X0 + picked.get_width()
            Y1 = Y0 + picked.get_height()
            removeBoxCoords = [X0, Y0, X1, Y1]

            picked.remove()
            self.patchesTotal = self.patchesTotal - 1

            category_lists = (
                self.eraseBoxXYValues,
                self.addBoxXYValues,
                self.oneWormLiveBoxXYValues,
                self.multiWormLiveBoxXYValues,
                self.oneWormDeadBoxXYValues,
                self.multiWormDeadBoxXYValues,
                self.autoDetectedBoxXYValues,
            )
            for boxes in category_lists:
                if removeBoxCoords in boxes:
                    boxes.remove(removeBoxCoords)

        self.canvas.draw()

    def initializeAnnotationDictionary(self):
        """Start with no current annotation frame and an empty per-frame patch record."""
        self.annotationRecordDictionary = dict()
        self.currentAnnotationFrame = None

    def updateAnnotationDictionary(self):
        """Record the patches currently on the axes under the current frame.

        Intended to be called when moving away from the current frame.  Does
        nothing when no frame has been applied yet.  A copy of the patch
        collection is stored: ``axes.patches`` is the axes' own (live)
        container, so keeping a reference would let the record change as
        patches are added or removed later.
        """
        previousFrame = self.currentAnnotationFrame
        if previousFrame is None:
            return
        self.annotationRecordDictionary[str(previousFrame)] = list(
            self.canvas.axes.patches)

    def getAnnotationDictionary(self):
        """Return the mapping from frame key (``str``) to recorded patches."""
        records = self.annotationRecordDictionary
        return records

    def applyAnnotationDictionary(self, frameNumber):
        """Make ``frameNumber`` current and show only the patches recorded for it.

        All patches are removed from the axes one by one, then the patches
        stored under ``str(frameNumber)`` (if any) are added back.  The former
        ``axes.patches = []`` assignment is avoided because recent Matplotlib
        versions expose ``Axes.patches`` as a read-only view.
        """
        self.currentAnnotationFrame = frameNumber
        records = self.annotationRecordDictionary

        # Freeze stored records first: an entry may be the axes' live patch
        # container, which would empty itself as patches are removed below.
        for key in list(records):
            records[key] = list(records[key])

        for patch in list(self.canvas.axes.patches):
            patch.remove()

        for patch in records.get(str(frameNumber), []):
            self.canvas.axes.add_patch(patch)

    def setAnnotationDictionary(self):
        """Placeholder: currently does nothing and returns ``None``."""
        return None
Ejemplo n.º 20
0
class Plot():
    '''
    Interactive node plot owned by a PlotPage-like parent.

    Keeps the figure, the node coordinates (xs/ys) and the mouse and
    keyboard handlers used to drag, highlight, edit and delete nodes.
    '''

    def __init__(self, parent, seconds, num):
        self.parent = parent

        self.num = num                      # Which number servo, for plot title
        self.scale_pos = 0
        self.length = (seconds * 2) + 1     # Two nodes per second, plus node 0

        # Drag state: a node follows the mouse only while it is held
        self.point_index = None
        self.node_clicked = False

        # Allowed servo range; limit_range clamps a value into it
        self.lower_limit = 0
        self.upper_limit = 179
        self.limit_range = lambda value: max(min(self.upper_limit, value), self.lower_limit)

        # Figure and its single axes
        self.fig = Figure(figsize=(10, 5), dpi=100)
        self.fig.subplots_adjust(bottom=0.18)
        self.ax = self.fig.add_subplot(111)

        # One node per tick, all starting at the application's default value
        self.xs = list(range(self.length))
        self.ys = [self.parent.parent.node_default_val for _ in self.xs]

        # Span-selector state
        self.selection = False
        self.span_xs = []

        self.setPlot()
        self.drawPlot()

    def setPlot(self):
        '''Configure the parts of the axes that never change between redraws'''

        axes = self.ax
        axes.set_ylim([-10, 190])
        axes.set_yticks(range(0, 190, 20))

        axes.grid(alpha=.5)
        axes.set_xlabel('Seconds')
        axes.set_ylabel('Degree of Motion', fontname='BPG Courier GPL&GNU',
                        fontsize=14)

    def clearPlotLines(self):
        '''Take the four plotted artists off the axes ahead of a redraw'''

        for artist in (self.line, self.nodes, self.upper, self.lower):
            artist.remove()

    def drawPlot(self):
        '''Plot limits, connecting line and clickable nodes for the current view'''

        self.ax.set_title(label=self.parent.name, fontsize=18, y=1.03)

        x_window = 20   # ticks visible at once
        # scale_pos is in seconds, ticks are half seconds; never scroll below 0
        first_tick = max(round(self.scale_pos * 2), 0)
        last_tick = first_tick + x_window

        # Keep every node inside the servo range
        self.ys = list(map(self.limit_range, self.ys))

        # Show only the current window, labelled in seconds
        self.ax.set_xlim([first_tick - .5, last_tick + .5])
        self.ax.set_xticks(list(range(first_tick, last_tick + 1)))
        self.ax.set_xticklabels([i / 2 for i in self.ax.get_xticks()])
        for label in self.ax.get_xticklabels():
            label.set_rotation(45)

        node_count = len(self.xs)

        # Dashed guides at the upper and lower limits
        self.upper, = self.ax.plot(self.xs, [self.upper_limit] * node_count,
                                   'k--', alpha=.6, linewidth=1)
        self.lower, = self.ax.plot(self.xs, [self.lower_limit] * node_count,
                                   'k--', alpha=.6, linewidth=1)

        # Line joining the nodes
        self.line, = self.ax.plot(self.xs, self.ys, color='orange',
                                  markersize=10)

        # The nodes themselves, pickable with the mouse
        self.nodes, = self.ax.plot(self.xs, self.ys, 'k.',
                                   markersize=10, picker=5.0)

    def createSpanSelector(self):
        '''Build and return the horizontal span selector for this axes'''

        highlight_style = {'alpha': 0.50, 'facecolor': 'lightskyblue'}
        return SpanSelector(self.ax, self.spanSelect, 'horizontal',
                            useblit=True, span_stays=False, button=1,
                            minspan=.05, rectprops=highlight_style)

    def onNodeClick(self, event):
        '''Pick handler: remember the node; a double-click opens the value popup'''

        self.point_index = int(event.ind[0])

        if not event.mouseevent.dblclick:
            # Single-click: the node starts following the mouse
            self.node_clicked = True
            return

        # Double-click: ask for a new value
        self.span.active = False
        sleep(.1)  # Needs short delay

        new_val, ok_cancel = ValuePopup(self.ys[self.point_index]).show()
        if not ok_cancel:
            return

        # With a highlight active every highlighted node takes the value,
        # otherwise only the node that was double-clicked
        if self.selection:
            targets = [self.xs.index(xp) for xp in self.span_xs]
        else:
            targets = [self.point_index]
        for target in targets:
            self.ys[target] = new_val

        self.update()

    def spanSelect(self, x_min, x_max):
        '''Span-selector callback: remember and highlight the spanned nodes'''

        node_count = len(self.xs)
        first = max(min(ceil(x_min), node_count), 0)
        stop = max(min(ceil(x_max), node_count), 0)

        picked = self.xs[first:stop]
        if len(picked) < 2:
            # Fewer than two nodes is not treated as a selection
            self.span_xs = []
            return

        self.span_xs = picked

        # Rectangle that stays on the plot to show what is selected
        self.highlight_rect = Rectangle((x_min, -10), width=(x_max - x_min),
                                        height=200, angle=0, alpha=0.35,
                                        facecolor='lightskyblue')
        self.ax.add_patch(self.highlight_rect)
        self.selection = True

    def onClick(self, event):
        '''Mouse press: restore canvas focus and drop the highlight when needed'''

        # Some events cause the canvas to lose focus and miss later events
        self.parent.canvas._tkcanvas.focus_set()

        button = event.button
        if button == 3:
            self.removeHighlight()
            return
        if button != 1:
            return

        if not self.node_clicked:
            # Left click on anything that is not a node
            self.removeHighlight()
            return

        # A node is being moved: keep the span selector out of the way
        self.span.active = False
        if self.selection and self.point_index not in self.span_xs:
            # The grabbed node is outside the highlight
            self.removeHighlight()

    def onMotion(self, event):
        '''Mouse move: drag the held node (and the highlighted nodes with it)'''

        if not (self.node_clicked and event.inaxes):
            return

        held = self.point_index
        before = self.ys[held]

        # The held node follows the mouse on the y-axis, inside the limits
        self.ys[held] = int(round(self.limit_range(event.ydata)))
        shift = self.ys[held] - before

        if self.selection:
            for xp in self.span_xs:
                if xp != held:
                    self.ys[xp] += shift

        self.update()

    def onRelease(self, event):
        '''Mouse release: re-enable the span selector and end any drag'''

        # The span selector is switched off by some mouse events
        self.span.active = True

        if self.point_index is None:
            return
        self.node_clicked = False
        self.point_index = None

    def removeHighlight(self):
        '''Remove the highlight rectangle and forget the selection'''

        if not self.selection:
            return

        self.highlight_rect.remove()
        del self.highlight_rect

        self.selection = False
        self.span_xs = []

        self.update()

    def onDelKey(self, event):
        '''Delete the highlighted nodes after confirmation, then append the
           same number of default-valued nodes so the length stays the same'''

        if not (self.selection and event.key == 'delete'):
            return

        if not messagebox.askyesno('Delete Nodes', message='Delete nodes?'):
            return

        first, last = self.span_xs[0], self.span_xs[-1]

        # Delete selected nodes
        del self.xs[first:last + 1]
        del self.ys[first:last + 1]

        # Pad the end to keep the routine length
        self.xs.extend(0 for _ in self.span_xs)
        self.ys.extend(self.parent.parent.node_default_val
                       for _ in self.span_xs)
        # Re-number xs
        self.xs = list(range(len(self.xs)))

        self.removeHighlight()

        self.update()

    def update(self):
        '''Redraw the plot from the current node values'''

        self.clearPlotLines()

        self.drawPlot()
        self.fig.canvas.draw()
Ejemplo n.º 21
0
class ROIdisplay():
    """Mouse handlers for drawing regions of interest (ROIs) on an image axes.

    The methods read state created elsewhere: ``_a_img`` (the axes),
    ``_canvas``, ``roiListbox`` and several variable objects read through
    ``.get()`` (presumably Tk variables -- confirm against the constructor):

    * ``_roidrawline`` / ``_roidrawcircle`` / ``_roidrawrect`` -- which shape
      to draw (value 1 means selected).
    * ``_menucheckRI`` -- artists are added to the axes only when it is 0.
    * ``_menucheckMS`` -- 0 stores the finished shape in the ROI lists,
      1 stores it as ``_refpath``.
    * ``_menucheckRM`` -- 1 also lists the new ROI in ``roiListbox``.
    """

    # Listbox prefix per drawing mode.
    _ROI_TYPE_NAMES = {'line': 'Line ', 'circle': 'Circle ', 'rect': 'Rectangle '}

    def _roi_mode(self):
        """Return 'line', 'circle' or 'rect' for the selected shape, else None."""
        if self._roidrawline.get() == 1:
            return 'line'
        if self._roidrawcircle.get() == 1:
            return 'circle'
        if self._roidrawrect.get() == 1:
            return 'rect'
        return None

    def _span_item(self, color):
        """Build the artist spanning the press point to the current point.

        Returns ``None`` when no drawing mode is selected, in which case the
        callers keep the previous ``_citem``.
        """
        mode = self._roi_mode()
        if mode == 'line':
            return Line2D([self._roix1, self._roix2],
                          [self._roiy1, self._roiy2],
                          c=color)
        if mode == 'circle':
            # Circle whose diameter is the press-to-current segment.
            x1 = np.mean([self._roix1, self._roix2])
            y1 = np.mean([self._roiy1, self._roiy2])
            r = np.sqrt((self._roix2 - x1)**2 + (self._roiy2 - y1)**2)
            return Circle((x1, y1), r, fc='none', ec=color)
        if mode == 'rect':
            # Anchor at the lower-left corner whatever the drag direction.
            width = np.abs(self._roix2 - self._roix1)
            height = np.abs(self._roiy2 - self._roiy1)
            x1 = np.min([self._roix1, self._roix2])
            y1 = np.min([self._roiy1, self._roiy2])
            return Rectangle((x1, y1), width, height, fc='none', ec=color)
        return None

    def on_mousepress(self, event):
        """Start a ROI at the press position with a unit-sized yellow shape."""
        self._cpressed = 1

        self._roix1 = event.xdata
        self._roiy1 = event.ydata

        mode = self._roi_mode()
        if mode == 'line':
            self._citem = Line2D([self._roix1, self._roix1 + 1],
                                 [self._roiy1, self._roiy1 + 1],
                                 c='yellow')
        elif mode == 'circle':
            self._citem = Circle((self._roix1, self._roiy1),
                                 1,
                                 fc='none',
                                 ec='yellow')
        elif mode == 'rect':
            self._citem = Rectangle((self._roix1, self._roiy1),
                                    1,
                                    1,
                                    fc='none',
                                    ec='yellow')

        if self._menucheckRI.get() == 0:
            self._a_img.add_artist(self._citem)
            self._canvas.draw()

    def on_mousedrag(self, event):
        """While the button is held, replace the shape with one that follows the mouse."""
        if self._cpressed != 1:
            return

        self._citem.remove()
        self._canvas.draw()
        self._roix2 = event.xdata
        self._roiy2 = event.ydata

        item = self._span_item('yellow')
        if item is not None:
            self._citem = item

        if self._menucheckRI.get() == 0:
            self._a_img.add_artist(self._citem)
            self._canvas.draw()

    def on_mouseup(self, event):
        """Finish the ROI: draw it in orange, label it and record it."""
        self._cpressed = 0

        self._roix2 = event.xdata
        self._roiy2 = event.ydata

        try:
            self._citem.remove()
        except AttributeError:
            # No shape has been created yet.
            pass

        # Number of the new ROI (1-based); the lists may not exist yet.
        length_roipath = len(getattr(self, '_roipath', ())) + 1

        mode = self._roi_mode()
        item = self._span_item('orange')
        if item is not None:
            self._citem = item

        # Stays None when the ROI is not drawn, so the code below never
        # touches an unbound name.
        roi_label = None
        if self._menucheckRI.get() == 0:
            self._a_img.add_artist(self._citem)
            roi_label = self._a_img.text(self._roix2,
                                         self._roiy2,
                                         str(length_roipath),
                                         ha="center",
                                         family='sans-serif',
                                         size=14,
                                         color='yellow')
            self._canvas.draw()

        if self._menucheckMS.get() == 0:
            if not hasattr(self, '_roipath'):
                self._roipath = []
            if not hasattr(self, '_roilabel'):
                self._roilabel = []
            # The two lists are kept parallel; a label slot may hold None.
            self._roipath.append(self._citem)
            self._roilabel.append(roi_label)

            # NOTE(review): 'ROI ' is a fallback for "no mode selected",
            # which previously raised; confirm the preferred wording.
            roi_type = self._ROI_TYPE_NAMES.get(mode, 'ROI ')

            if self._menucheckRM.get() == 1:
                self.roiListbox.insert('end', roi_type + str(length_roipath))

        elif self._menucheckMS.get() == 1:
            self._refpath = self._citem

    def noshow_roi(self):
        """Remove every recorded ROI shape and label from the axes."""
        for item in getattr(self, '_roipath', ()):
            item.remove()
        for item in getattr(self, '_roilabel', ()):
            if item is not None:
                item.remove()

        # Read the variable's value; comparing the variable object itself
        # to 0 is never true, so the canvas was never redrawn.
        if self._menucheckRI.get() == 0:
            self._canvas.draw()
Ejemplo n.º 22
0
def conv_plot(kernel_size=3, stride=1, padding=True, speed=1):
    """Animate a kernel sliding over a random 7x7 input.

    The left axes shows the input grid with the current kernel window
    outlined in red; the right axes shows the output grid being filled in
    with random grey values, one cell per step.

    kernel_size must be 1, 3 or 5; stride must be 1, 2 or 3; padding must be
    True or False, and without padding the stride must be 1.  Each step
    pauses for ``0.5 / speed`` seconds.
    """
    assert (kernel_size in [1, 3, 5])
    assert (stride in [1, 2, 3])
    assert (padding in [True, False])
    assert (padding or stride == 1)  # Without padding set stride to 1

    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    image = np.random.rand(7, 7)
    input_ax, output_ax = axes[0], axes[1]

    # Bare 11x11 canvas on both axes: no tick labels, tick marks or spines
    for ax in axes:
        ax.axis([0, 11, 0, 11])
        for axis in (ax.xaxis, ax.yaxis):
            axis.set_ticklabels([])
        for axis in (ax.xaxis, ax.yaxis):
            axis.set_ticks_position('none')
        for side in ['bottom', 'top', 'left', 'right']:
            ax.spines[side].set_visible(False)

    # Input grid, each cell shaded by its random value
    input_ax.set_title('Input $X$')
    for i, j in itertools.product(range(7), repeat=2):
        cell = Rectangle(xy=(i + 2, j + 2),
                         width=1,
                         height=1,
                         edgecolor='black',
                         facecolor=str(image[i, j]))
        input_ax.add_patch(cell)

    half = kernel_size // 2
    output_size = {1: 7, 2: 4, 3: 3}[stride]
    if not padding:
        output_size -= half * 2
    # Offset that centres the output grid in the same area as the input
    shift = 2 + (7 - output_size) // 2

    # Empty (white) output grid
    output_ax.set_title('Output $Z$')
    for i, j in itertools.product(range(output_size), repeat=2):
        blank = Rectangle(xy=(i + shift, 10 - j - shift),
                          width=1,
                          height=1,
                          edgecolor='black',
                          facecolor='1.0')
        output_ax.add_patch(blank)

    # Walk the output cells row by row, starting from the top
    for j, i in itertools.product(range(output_size), range(output_size)):
        output_coord = (i + shift, 10 - j - shift)

        filled = Rectangle(output_coord,
                           1,
                           1,
                           edgecolor='black',
                           facecolor=str(np.random.rand()))
        out_marker = Rectangle(output_coord,
                               1,
                               1,
                               edgecolor='red',
                               facecolor='none',
                               linewidth='3')
        output_ax.add_patch(filled)
        output_ax.add_patch(out_marker)

        # Lower-left corner of the kernel window on the input grid
        x = i * stride + 2 - half
        y = 8 - j * stride - half
        if not padding:
            x += half
            y -= half
        in_marker = Rectangle((x, y),
                              kernel_size,
                              kernel_size,
                              edgecolor='red',
                              facecolor='none',
                              linewidth='3')
        input_ax.add_patch(in_marker)

        fig.canvas.draw()
        sleep(0.5 / speed)

        # Only the red outlines are temporary; the filled cell stays
        in_marker.remove()
        out_marker.remove()
0
class LensGUI:
    """Tk window with a 2x2 grid of Matplotlib axes plus control buttons.

    Panels: ``a1`` (top-left) shows the image, ``a2`` (top-right) starts as
    the image and later shows residuals, ``a3`` (bottom-left) later shows the
    model, ``a4`` (bottom-right) is handed to ``objectMover.ObjMover``.
    Most button actions are delegated to ``parent``.
    """

    def __init__(self,parent):
        """Build the window, the figure with its four axes, and the buttons.

        ``parent`` must provide ``img``, ``color`` and ``b`` as well as the
        callbacks wired to the buttons in :meth:`addButtons`.
        """
        self.root = Tk.Tk()

        self.parent = parent
        self.img = self.parent.img
        self.color = self.parent.color

        self.mover = None

        f1 = Figure((12.06,12.06))
        # 'lower' is the documented spelling for a bottom-left origin;
        # 'bottom' is rejected by current Matplotlib.
        a1 = f1.add_axes([0,101./201,100./201,100./201])
        self.img1 = a1.imshow(self.img,origin='lower',interpolation='nearest')
        a1.set_xticks([])
        a1.set_yticks([])
        # Every other panel reuses the limits of the first one.
        xlim = a1.get_xlim()
        ylim = a1.get_ylim()

        a2 = f1.add_axes([101./201,101./201,100./201,100./201])
        self.img2 = a2.imshow(self.img,origin='lower',interpolation='nearest')
        a2.set_xlim(xlim)
        a2.set_ylim(ylim)
        a2.set_xticks([])
        a2.set_yticks([])

        a3 = f1.add_axes([0.,0.,100./201,100./201])
        self.img3 = a3.imshow(self.img*0+1,origin='lower',interpolation='nearest')
        a3.set_xlim(xlim)
        a3.set_ylim(ylim)
        a3.set_xticks([])
        a3.set_yticks([])

        a4 = f1.add_axes([101./201,0.,100./201,100./201])
        a4.imshow(self.parent.b*0)
        a4.cla()
        a4.set_xlim(xlim)
        a4.set_ylim(ylim)
        a4.set_xticks([])
        a4.set_yticks([])

        canvas = FigureCanvasTkAgg(f1,master=self.root)
        # draw() is available on every Matplotlib version; the old show()
        # alias has been removed.
        canvas.draw()
        canvas.get_tk_widget().pack(side=Tk.TOP,fill=Tk.BOTH,expand=1)
        # NOTE(review): NavigationToolbar2TkAgg was renamed
        # NavigationToolbar2Tk in newer Matplotlib; the name comes from the
        # file's imports, so it is left as is here.
        toolbar = NavigationToolbar2TkAgg(canvas,self.root )
        toolbar.update()
        canvas._tkcanvas.pack(side=Tk.TOP, fill=Tk.BOTH, expand=1)
        bFrame = Tk.Frame(self.root)
        bFrame.pack(side=Tk.TOP,fill=Tk.BOTH,expand=1)

        self.f1 = f1
        self.a1 = a1
        self.a2 = a2
        self.a3 = a3
        self.a4 = a4
        self.bFrame = bFrame

        self.canvas = canvas
        self.toolbar = toolbar

        # Rectangle patch showing the mask on a1, if any.
        self.rubberBox = None

        self.addButtons()

    def addButtons(self):
        """Create the toolbar buttons and the Save/Load/Mask buttons below."""
        self.activeButton = None
        self.bAGtext = Tk.StringVar()
        self.bAGtext.set('Add Galaxy')
        self.buttonAG = Tk.Button(self.toolbar,textvariable=self.bAGtext,command=self.parent.addGal,width=10)
        self.buttonAG.pack(side=Tk.LEFT)

        self.bALtext = Tk.StringVar()
        self.bALtext.set('Add Lens')
        self.buttonAL = Tk.Button(self.toolbar,textvariable=self.bALtext,command=self.parent.addLens,width=10)
        self.buttonAL.pack(side=Tk.LEFT)

        self.bAStext = Tk.StringVar()
        self.bAStext.set('Add Source')
        self.buttonAS = Tk.Button(self.toolbar,textvariable=self.bAStext,command=self.parent.addSrc,width=10)
        self.buttonAS.pack(side=Tk.LEFT)
        self.buttonAS.configure(state='disabled')

        self.buttonFit = Tk.Button(self.toolbar,text='Fit Light',command=self.parent.fitLight,width=10)
        self.buttonFit.pack(side=Tk.LEFT)

        self.bOpttext = Tk.StringVar()
        self.bOpttext.set('Optimize')
        self.buttonOptimize = Tk.Button(self.toolbar,textvariable=self.bOpttext,command=self.parent.optimize,width=10)
        self.buttonOptimize.pack(side=Tk.LEFT)

        self.buttonSave = Tk.Button(self.bFrame,text='Save',command=self.parent.saveState,width=10)
        self.buttonSave.pack(side=Tk.LEFT)

        self.buttonLoad = Tk.Button(self.bFrame,text='Load',command=self.parent.loadState,width=10)
        self.buttonLoad.pack(side=Tk.LEFT)

        self.bAMtext = Tk.StringVar()
        self.bAMtext.set('Add Mask')
        self.buttonMask = Tk.Button(self.bFrame,textvariable=self.bAMtext,command=self.addMask,width=10)
        self.buttonMask.pack(side=Tk.LEFT)


    def deactivateButtons(self):
        """Leave the active mode: restore its label and disconnect its handlers."""
        if self.toolbar.mode!='':
            # Presumably this toggle sequence switches off whichever
            # navigation mode is active -- TODO confirm.
            self.toolbar.zoom()
            self.toolbar.pan()
            self.toolbar.pan()
        if self.activeButton==self.buttonAG:
            self.bAGtext.set('Add Galaxy')
            self.canvas.mpl_disconnect(self.pressid)
        elif self.activeButton==self.buttonAL:
            self.bALtext.set('Add Lens')
            self.canvas.mpl_disconnect(self.pressid)
        elif self.activeButton==self.buttonAS:
            self.bAStext.set('Add Source')
            self.canvas.mpl_disconnect(self.pressid)
        elif self.activeButton==self.buttonMask:
            self.bAMtext.set('Add Mask')
            self.canvas.mpl_disconnect(self.pressid)
            self.canvas.mpl_disconnect(self.moveid)
            self.canvas.mpl_disconnect(self.releaseid)
        self.pressid = None
        self.moveid = None
        self.releaseid = None
        self.activeButton = None


    def addMask(self,loaded=False):
        """Toggle rubber-band mask drawing on ``a1``, or show a loaded mask.

        With ``loaded=True`` and an existing ``parent.mask``, only the
        bounding box of the mask is drawn.  Otherwise the first call enters
        mask mode (the button reads 'Cancel') and a second call leaves it.
        In mask mode a drag on ``a1`` defines a rectangle; on release
        ``parent.mask`` becomes a boolean array that is True inside it.
        """
        if loaded and self.parent.mask is not None:
            import numpy
            from matplotlib.patches import Rectangle
            y,x = numpy.where(self.parent.mask==1)
            if x.size==0:
                # Nothing is masked, so there is no bounding box to draw.
                return
            x0,x1,y0,y1 = x.min(),x.max(),y.min(),y.max()
            self.rubberBox = Rectangle((x0,y0),x1-x0,y1-y0,fc='none',ec='w')
            self.a1.add_patch(self.rubberBox)
            self.canvas.draw()
            return
        if self.activeButton==self.buttonMask:
            self.deactivateButtons()
            return
        self.deactivateButtons()
        self.xmask = None
        def onPress(event):
            # The anchor is recorded only for presses inside a1; any press
            # discards the previous box.
            axes = event.inaxes
            if axes==self.a1:
                self.xmask = event.xdata
                self.ymask = event.ydata
            if self.rubberBox is not None:
                self.rubberBox.remove()
                self.rubberBox = None
        def onMove(event):
            if self.xmask is None:
                return
            axes = event.inaxes
            if axes==self.a1:
                x,y = event.xdata,event.ydata
                dx = x-self.xmask
                dy = y-self.ymask
                if self.rubberBox is None:
                    from matplotlib.patches import Rectangle
                    self.rubberBox = Rectangle((self.xmask,self.ymask),
                                                dx,dy,fc='none',ec='w')
                    self.a1.add_patch(self.rubberBox)
                else:
                    self.rubberBox.set_height(dy)
                    self.rubberBox.set_width(dx)
                self.canvas.draw()
        def onRelease(event):
            if self.rubberBox is None or self.xmask is None:
                # A click without a drag inside a1 leaves nothing to apply;
                # stay in mask mode so the user can try again.
                return
            dy = int(self.rubberBox.get_height())
            dx = int(self.rubberBox.get_width())
            x0,y0 = int(self.xmask),int(self.ymask)
            x1,y1 = x0+dx,y0+dy
            # Dragging up or to the left gives negative extents; order the
            # bounds (and keep them non-negative) so the slice is not empty.
            xlo,xhi = max(min(x0,x1),0),max(x0,x1,0)
            ylo,yhi = max(min(y0,y1),0),max(y0,y1,0)
            self.parent.mask = self.parent.imgs[0]*0
            self.parent.mask[ylo:yhi,xlo:xhi] = 1
            self.parent.mask = self.parent.mask==1
            self.deactivateButtons()
        self.pressid = self.canvas.mpl_connect('button_press_event',onPress)
        self.moveid = self.canvas.mpl_connect('motion_notify_event',onMove)
        self.releaseid = self.canvas.mpl_connect('button_release_event',onRelease)
        self.bAMtext.set('Cancel')
        self.activeButton = self.buttonMask


    def showResid(self):
        """Refresh the residual (``a2``) and model (``a3``) panels.

        Without models the plain image is shown on ``a2`` and ``a3`` is
        cleared.  With ``self.color`` set, two or three bands are combined
        through ``self.color.colorize``; with two bands the green channel is
        the mean of the other two.
        """
        if self.parent.models is None:
            self.a2.imshow(self.parent.img,origin='lower',
                            interpolation='nearest')
            self.a3.cla()
            self.a3.set_xticks([])
            self.a3.set_yticks([])
            self.canvas.draw()
            return
        models = self.parent.models
        imgs = self.parent.imgs
        nimgs = self.parent.nimgs
        if self.color is not None:
            if nimgs==2:
                b = imgs[0]-models[0]
                r = imgs[1]-models[1]
                g = (b+r)/2.
                resid = self.color.colorize(b,g,r)
                b = models[0]
                r = models[1]
                g = (b+r)/2.
                model = self.color.colorize(b,g,r,newI=True)
            else:
                b = imgs[0]-models[0]
                g = imgs[1]-models[1]
                r = imgs[2]-models[2]
                resid = self.color.colorize(b,g,r)
                b = models[0]
                g = models[1]
                r = models[2]
                model = self.color.colorize(b,g,r,newI=True)
        else:
            resid = imgs[0]-models[0]
            model = models[0]
            self.img3.set_clim([0.,model.max()])
        self.img2.set_data(resid)
        self.img3.set_data(model)
        self.canvas.draw()

    def redrawSymbols(self):
        """Replace the object mover on ``a4`` with a fresh one."""
        import objectMover
        if self.mover is not None:
            self.mover.remove()
        self.mover = objectMover.ObjMover(self.parent,self.a4,self.canvas)
Ejemplo n.º 24
0
class Picker_plot(object):
    """
    Plot data so that a user can toggle points between inliers and outliers.

    A click on a point toggles that point; dragging a rectangle toggles every
    point strictly inside it.  Outliers are drawn in red over the blue data.

    Parameters
    -----------
    ax : matplotlib ax instance
    data : Nx2 data array
    inliers : optional iterable of row indices initially treated as inliers;
        all rows are inliers when omitted
    """
    def __init__(self, ax, data, inliers=None):
        self.ax = ax
        self.data = data
        canvas = self.ax.figure.canvas
        canvas.mpl_connect('pick_event', self.on_pick)
        canvas.mpl_connect('button_press_event', self.on_click)
        canvas.mpl_connect('motion_notify_event', self.on_motion_notify)
        canvas.mpl_connect('button_release_event', self.on_release)

        # Real lists are required: update_data appends to and removes from
        # both, which a bare range object does not support.
        data_range = list(range(data.shape[0]))
        if inliers is not None:
            self.inliers = list(inliers)
            self.outliers = list(np.setdiff1d(data_range, inliers))
        else:
            self.inliers = data_range
            self.outliers = []
        self.line = None
        self.oline = None

        # Rubber-band state: press corner, current corner, rectangle artist.
        self.rect = None
        self.x0 = None
        self.y0 = None
        self.x1 = None
        self.y1 = None

        self.draw()

    def draw(self):
        """Plot the data once and re-plot the outlier overlay."""
        if self.line is None:
            self.line, = self.ax.plot(self.data[:,0], self.data[:,1], 'bo', picker=3)

        if self.oline is not None:
            self.oline.remove()
            self.oline = None

        self.oline, = self.ax.plot(self.data[self.outliers,0],
                self.data[self.outliers,1], 'ro')

    def update_data(self, index):
        """Toggle each row index in ``index`` and redraw."""
        for ind in index:
            if ind in self.outliers:
                # Single-string form prints the same text on Python 2 and 3.
                print('{Picker_plot} user removed index: %s' % (ind,))
                self.inliers.append(ind)
                self.outliers.remove(ind)
            else:
                print('{Picker_plot} user added index: %s' % (ind,))
                self.outliers.append(ind)
                self.inliers.remove(ind)

        self.draw()
        self.ax.figure.canvas.draw()

    def on_pick(self, event):
        """Toggle the picked points when the pick hit the data line."""
        if event.artist != self.line: return
        self.update_data(event.ind)

    def on_click(self, event):
        """Remember the press position unless a toolbar mode is active."""
        toolbar = plt.get_current_fig_manager().toolbar
        # Some backends have no toolbar at all.
        if toolbar is not None and toolbar.mode != '': return
        self.x0, self.y0 = event.xdata, event.ydata

    def on_motion_notify(self, event):
        """Redraw the rubber-band rectangle while dragging inside the axes."""
        # Compare with None: a press at data coordinate 0 is a valid start.
        if self.x0 is None or not event.inaxes: return

        self.x1, self.y1 = event.xdata, event.ydata
        if self.rect is not None:
            self.rect.remove()
            self.rect = None
        self.rect = Rectangle([self.x0,self.y0], event.xdata-self.x0,
                event.ydata-self.y0, color='k', fc='none',lw=1)
        self.ax.add_artist(self.rect)
        self.ax.figure.canvas.draw()

    def on_release(self, event):
        """Finish the drag and toggle the points strictly inside the rectangle."""
        if self.x0 is None: return
        if event.inaxes:
            self.x1, self.y1 = event.xdata, event.ydata

        corners = (self.x0, self.y0, self.x1, self.y1)

        self.x0,self.y0 = None, None
        self.x1,self.y1 = None, None
        if self.rect is not None:
            self.rect.remove()
            self.rect = None

        if any(c is None for c in corners):
            # Released outside the axes without ever moving inside them:
            # there is no rectangle to evaluate.
            self.ax.figure.canvas.draw()
            return

        x0 = min(corners[0], corners[2])
        y0 = min(corners[1], corners[3])
        x1 = max(corners[0], corners[2])
        y1 = max(corners[1], corners[3])

        ind = np.where((self.data[:,0]>x0) & (self.data[:,0]<x1) & (self.data[:,1]>y0) & (self.data[:,1]<y1))[0]
        self.update_data(ind)
Ejemplo n.º 25
0
class AxesHelper(object):
    """Right-button range/rectangle selection and histogram display on an axes.

    ``data_client`` must provide ``histogram1d`` and ``histogram2d`` with the
    call signatures used below.  ``selection_callback``, when given, is
    called with the selection dictionary each time a selection is finished.
    """

    def __init__(self, ax, data_client, selection_callback=None):
        self.ax = ax
        self.data_client = data_client
        self.selection_in_progress = False
        self._rectangle = None
        self.x_attribute = None
        self.y_attribute = None
        self.selection = {}
        # Refresh handler installed by histogram1d()/histogram2d().
        self.update = None
        self._update_cid = None
        # True between finishing a selection and the next histogram refresh.
        self._selection_dirty = False
        canvas = self.ax.figure.canvas
        canvas.mpl_connect('button_press_event', self.start_selection)
        canvas.mpl_connect('motion_notify_event', self.update_selection)
        canvas.mpl_connect('button_release_event', self.finalize_selection)
        self.selection_callback = selection_callback

    def start_selection(self, event):
        """Begin a selection on a right-button press inside the axes.

        Without a ``y_attribute`` the selection is an x range drawn across
        the full current y extent, otherwise a rectangle.

        Raises ``ValueError`` if a selection is already in progress.
        """
        if event.inaxes is not self.ax:
            return
        if event.button != 3:
            return
        if self.selection_in_progress:
            raise ValueError("Selection already in progress, unexpected error")
        self.selection_in_progress = True
        if self.y_attribute is None:
            ymin, ymax = self.ax.get_ylim()
            self.selection = {'type': 'range',
                              'x_min': event.xdata,
                              'x_max': event.xdata,
                              'x_attribute': self.x_attribute}
            self._rectangle = Rectangle((self.selection['x_min'], ymin),
                                        width=0, height=ymax - ymin, edgecolor='red', facecolor='none')
        else:
            self.selection = {'type': 'rectangle',
                              'x_min': event.xdata,
                              'x_max': event.xdata,
                              'y_min': event.ydata,
                              'y_max': event.ydata,
                              'x_attribute': self.x_attribute,
                              'y_attribute': self.y_attribute
                              }
            self._rectangle = Rectangle((self.selection['x_min'], self.selection['y_min']),
                                        width=0, height=0, edgecolor='red', facecolor='none')
        self.ax.add_patch(self._rectangle)

    def update_selection(self, event):
        """Stretch the selection to the mouse position while it stays in the axes."""
        if event.inaxes is not self.ax:
            return
        if not self.selection_in_progress:
            return

        self.selection['x_max'] = event.xdata
        width = self.selection['x_max'] - self.selection['x_min']
        self._rectangle.set_width(width)

        if 'y_max' in self.selection:
            self.selection['y_max'] = event.ydata
            height = self.selection['y_max'] - self.selection['y_min']
            self._rectangle.set_height(height)

        self.ax.figure.canvas.draw()

    def _normalize_selection(self):
        """Swap bounds so that ``*_min <= *_max`` after a leftward/downward drag."""
        for low, high in (('x_min', 'x_max'), ('y_min', 'y_max')):
            if low in self.selection and high in self.selection:
                if self.selection[low] > self.selection[high]:
                    self.selection[low], self.selection[high] = (
                        self.selection[high], self.selection[low])

    def finalize_selection(self, event):
        """Finish the selection on button release and notify the callback.

        The release is honoured wherever it happens: ignoring a release
        outside the axes would leave the selection permanently in progress
        and the rectangle stuck on the plot.  The selection then keeps the
        last position seen inside the axes.
        """
        if not self.selection_in_progress:
            return
        self._rectangle.remove()
        self._rectangle = None
        self.selection_in_progress = False
        self._normalize_selection()
        self._selection_dirty = True
        self.ax.figure.canvas.draw()
        if self.selection_callback is not None:
            self.selection_callback(self.selection)

    def set_selection(self, selection):
        """Replace the selection and refresh the histogram, if one is configured."""
        self.selection = selection
        if self.update is not None:
            self.update()

    def _connect_update(self, handler):
        """Install ``handler`` as the only release-driven refresh handler."""
        canvas = self.ax.figure.canvas
        if self._update_cid is not None:
            # Reconfiguring must not leave the previous handler connected.
            canvas.mpl_disconnect(self._update_cid)
        self._update_cid = canvas.mpl_connect('button_release_event', handler)
        self.update = handler

    def histogram1d(self, x_attribute, x_nbin):
        """Configure the helper to show a 1D histogram of ``x_attribute``."""
        self.x_attribute = x_attribute
        self.x_nbin = x_nbin
        self.histogram = None
        self._connect_update(self.update_histogram1d)

    def update_histogram1d(self, event=None):
        """Recompute and redraw the 1D histogram for the current x limits."""
        if self.selection_in_progress:
            return
        self._selection_dirty = False
        x_range = self.ax.get_xlim()
        midpoints, array = self.data_client.histogram1d(self.x_attribute, x_range, self.x_nbin, selection=self.selection)
        if self.histogram is not None:
            for patch in self.histogram:
                patch.remove()
        # Autoscaling is off while plotting so the view limits stay put.
        self.ax.set_autoscale_on(False)
        self.histogram = self.ax.plot(midpoints, array, drawstyle='steps-mid', color='k')
        self.ax.set_autoscale_on(True)
        self.ax.figure.canvas.draw()

    def histogram2d(self, x_attribute, y_attribute, x_nbin, y_nbin):
        """Configure the helper to show a 2D histogram of the two attributes."""
        self.x_attribute = x_attribute
        self.y_attribute = y_attribute
        self.x_nbin = x_nbin
        self.y_nbin = y_nbin
        self.image = None
        self._connect_update(self.update_histogram2d)

    def update_histogram2d(self, event=None):
        """Recompute and redraw the 2D histogram for the current view limits.

        Releases outside this axes are ignored unless a selection has just
        been finished and not yet displayed.
        """
        if self.selection_in_progress:
            return
        outside = event is not None and event.inaxes is not self.ax
        if outside and not self._selection_dirty:
            return
        self._selection_dirty = False
        x_range = self.ax.get_xlim()
        y_range = self.ax.get_ylim()
        array = self.data_client.histogram2d(self.x_attribute, self.y_attribute,
                                             x_range, y_range, self.x_nbin, self.y_nbin, selection=self.selection)
        self.ax.set_autoscale_on(False)
        if self.image is None:
            # Drawn in axes coordinates so the image always fills the axes.
            self.image = self.ax.imshow(array, aspect='auto', extent=[0, 1, 0, 1], transform=self.ax.transAxes)
        else:
            self.image.set_data(array)
        self.ax.set_autoscale_on(True)
        self.ax.figure.canvas.draw()
Ejemplo n.º 26
0
class PixelInteractor(QObject):
    """Interactive square aperture ("pixel") drawn on a matplotlib axes.

    A square ``Rectangle`` patch and a single square marker are added to
    ``ax``.  The marker sits at the centre of the square for the rotation
    stored in ``rect.angle``.  Pressing the left mouse button within
    ``epsilon`` (measured in data coordinates) of the marker and dragging
    translates the square.

    Key bindings (only while the pointer is inside an axes):
        ``t``  toggle the marker visibility,
        ``d``  emit ``mySignal('rectangle deleted')``.

    ``modSignal('rectangle modified')`` is emitted on every drag step.
    """

    epsilon = 10
    showverts = True
    mySignal = pyqtSignal(str)
    modSignal = pyqtSignal(str)

    def __init__(self, ax, corner, width, angle=0.):
        """Create the square patch and its marker and hook up the events.

        ax: axes the aperture is drawn on.
        corner: ``(x, y)`` anchor passed to ``Rectangle``.
        width: side length of the square.
        angle: rotation passed to ``Rectangle`` (degrees).
        """
        super().__init__()
        from matplotlib.patches import Rectangle
        from matplotlib.lines import Line2D
        # To avoid crashing with maximum recursion depth exceeded.
        # 10000 is 10x the default value; never lower a limit that another
        # part of the program has already raised above that.
        import sys
        sys.setrecursionlimit(max(sys.getrecursionlimit(), 10000))

        self.type = 'Pixel'
        height = width
        self.ax = ax
        self.angle = angle
        self.width = width
        self.height = width
        self.rect = Rectangle(corner, width, height, edgecolor='Lime', facecolor='none',
                              angle=angle, fill=False, animated=True)
        self.ax.add_patch(self.rect)
        self.canvas = self.rect.figure.canvas

        x, y = self.compute_markers()
        self.line = Line2D(x, y, marker='s', linestyle=None, linewidth=0.,
                           markerfacecolor='g', animated=True)
        self.ax.add_line(self.line)

        self.cid = self.rect.add_callback(self.rectangle_changed)
        self._ind = None  # the active point
        # Cached canvas background used for blitting; filled by draw_callback.
        self.background = None

        self.connect()

        self.aperture = self.rect
        self.press = None
        self.lock = None

    def compute_markers(self):
        """Return the marker position as two one-element lists ``(x, y)``.

        The position is derived from the rectangle's anchor, width and
        angle and is also stored in ``self.xy`` as ``[(x, y)]``.
        """
        w0 = self.rect.get_width()
        x0, y0 = self.rect.get_xy()
        angle0 = self.rect.angle

        half_diagonal = w0 / np.sqrt(2.)
        direction = (45. - angle0) * np.pi / 180.
        x = [x0 + half_diagonal * np.sin(direction)]
        y = [y0 + half_diagonal * np.cos(direction)]

        self.xy = [(x, y)]
        return x, y

    def connect(self):
        """Connect all canvas callbacks and request a redraw."""
        self.cid_draw = self.canvas.mpl_connect('draw_event', self.draw_callback)
        self.cid_press = self.canvas.mpl_connect('button_press_event', self.button_press_callback)
        self.cid_release = self.canvas.mpl_connect('button_release_event', self.button_release_callback)
        self.cid_motion = self.canvas.mpl_connect('motion_notify_event', self.motion_notify_callback)
        self.cid_key = self.canvas.mpl_connect('key_press_event', self.key_press_callback)
        self.canvas.draw_idle()

    def disconnect(self):
        """Disconnect the canvas callbacks and remove both artists."""
        self.canvas.mpl_disconnect(self.cid_draw)
        self.canvas.mpl_disconnect(self.cid_press)
        self.canvas.mpl_disconnect(self.cid_release)
        self.canvas.mpl_disconnect(self.cid_motion)
        self.canvas.mpl_disconnect(self.cid_key)
        self.rect.remove()
        self.line.remove()
        self.canvas.draw_idle()
        self.aperture = None

    def draw_callback(self, event):
        """Cache the axes background, then draw the animated artists."""
        self.background = self.canvas.copy_from_bbox(self.ax.bbox)
        self.ax.draw_artist(self.rect)
        self.ax.draw_artist(self.line)

    def rectangle_changed(self, rect):
        """Callback registered on the rectangle via ``add_callback``.

        Copies the artist properties of ``rect`` to the marker line while
        keeping the line's own visibility.
        """
        # Imported here so the method does not depend on a module-level
        # import of ``Artist``.
        from matplotlib.artist import Artist

        vis = self.line.get_visible()
        Artist.update_from(self.line, rect)
        self.line.set_visible(vis)

    def get_ind_under_point(self, event):
        """Return ``0`` if the event is within ``epsilon`` of the marker, else ``None``."""
        if event.xdata is None or event.ydata is None:
            return None

        x, y = self.xy[0]
        # ``x`` and ``y`` are stored as one-element lists; convert them
        # explicitly instead of relying on list-minus-numpy-scalar coercion.
        d = np.hypot(np.asarray(x, dtype=float) - event.xdata,
                     np.asarray(y, dtype=float) - event.ydata)

        if float(np.min(d)) >= self.epsilon:
            return None
        return 0

    def button_press_callback(self, event):
        """Record the rectangle state and pointer position on left press."""
        if not self.showverts:
            return
        if event.inaxes is None:
            return
        if event.button != 1:
            return
        self._ind = self.get_ind_under_point(event)
        x0, y0 = self.rect.get_xy()
        w0, h0 = self.rect.get_width(), self.rect.get_height()
        theta0 = self.rect.angle / 180 * np.pi
        self.press = x0, y0, w0, h0, theta0, event.xdata, event.ydata
        self.xy0 = self.xy

        self.lock = "pressed"

    def key_press_callback(self, event):
        """Handle ``t`` (toggle marker) and ``d`` (request deletion)."""
        if not event.inaxes:
            return

        if event.key == 't':
            self.showverts = not self.showverts
            self.line.set_visible(self.showverts)
            if not self.showverts:
                self._ind = None
        elif event.key == 'd':
            self.mySignal.emit('rectangle deleted')

        self.canvas.draw_idle()

    def button_release_callback(self, event):
        """Finish a drag on left release and request a full redraw."""
        if not self.showverts:
            return
        if event.button != 1:
            return
        self._ind = None
        self.press = None
        self.lock = "released"
        # Invalidate the cache; the redraw below refills it via draw_callback.
        self.background = None
        # To get other aperture redrawn
        self.canvas.draw_idle()

    def motion_notify_callback(self, event):
        """Translate the square while the marker is dragged with button 1."""
        if not self.showverts:
            return
        if self._ind is None or self.press is None:
            return
        if event.inaxes is None:
            return
        if event.button != 1:
            return

        xpress, ypress = self.press[5], self.press[6]
        self.dx = event.xdata - xpress
        self.dy = event.ydata - ypress
        self.update_rectangle()

        if self.background is None:
            # No cached background yet (no draw since the last release):
            # fall back to a regular redraw instead of blitting from None.
            self.canvas.draw_idle()
        else:
            # Redraw rectangle and points
            self.canvas.restore_region(self.background)
            self.ax.draw_artist(self.rect)
            self.ax.draw_artist(self.line)
            self.canvas.update()
            self.canvas.flush_events()

        # Notify callback
        self.modSignal.emit('rectangle modified')

    def update_rectangle(self):
        """Move the rectangle by the current drag offset ``(dx, dy)``.

        The first motion after a press only switches ``lock`` from
        ``"pressed"`` to ``"move"``.  While moving, an axis whose new anchor
        coordinate would become negative keeps its pre-drag value.
        """
        if self.press is None:
            return
        x0, y0 = self.press[0], self.press[1]
        dx, dy = self.dx, self.dy

        if self.lock == "pressed":
            self.lock = "move"
        elif self.lock == "move":
            xn = x0 if x0 + dx < 0 else x0 + dx
            yn = y0 if y0 + dy < 0 else y0 + dy
            self.rect.set_xy((xn, yn))
            # The marker (and ``self.xy``) is recomputed from the rectangle;
            # translating the stored lists by hand was redundant and failed
            # with a TypeError whenever an offset had been reset to int 0.
            self.updateMarkers()

    def updateMarkers(self):
        """Recompute the marker position and push it to the line artist."""
        x, y = self.compute_markers()
        self.line.set_data(x, y)
Ejemplo n.º 27
0
class RectangleBuilder(AxesWidget):
    """Widget that builds a band of adjustable width around a drawn line.

    A ``LineBuilder`` first collects a line.  When it reports its two
    vertices (``line_callback``), two mirrored semi-transparent rectangles
    are anchored on the line and their width follows the mouse.  The next
    button press ends the interaction and, if the width is non-zero, calls
    ``callback(origin, vect, width)``.

    Attributes stored on ``self.line`` by ``line_callback``:
        origin  first vertex of the line
        vect    vector from the first to the second vertex
        mod     length of the line
        angle   rotation handed to the rectangles
    """

    def __init__(self, ax, callback=None, useblit=False):
        """Create the widget on ``ax``; ``callback`` receives the final band."""
        super().__init__(ax)

        self.useblit = useblit and self.canvas.supports_blit
        if self.useblit:
            self.background = self.canvas.copy_from_bbox(self.ax.bbox)

        self.line = LineBuilder(ax, callback=self.line_callback,
                                useblit=useblit, linekargs={'color': 'red'})
        self.callback = callback
        self.width = 0.0
        self.Rleft = None
        self.Rright = None

    def line_callback(self, verts):
        """Start the width selection once the line has been drawn.

        ``verts`` holds the two end points of the line.
        """
        x0, y0 = verts[0]
        x1, y1 = verts[1]
        self.line.origin = np.array([x0, y0])
        self.line.vect = np.array([x1 - x0, y1 - y0])
        self.line.mod = np.sqrt(self.line.vect @ self.line.vect)
        self.line.angle = -np.arctan2(*self.line.vect) / rpd
        self.width = 0.0
        # ``angle`` is passed by keyword: recent matplotlib versions no
        # longer accept it positionally.
        self.Rleft = Rectangle(self.line.origin, self.width, self.line.mod,
                               angle=self.line.angle, color='r', alpha=0.3)
        self.Rright = Rectangle(self.line.origin, -self.width, self.line.mod,
                                angle=self.line.angle, color='r', alpha=0.3)
        self.ax.add_patch(self.Rleft)
        self.ax.add_patch(self.Rright)
        self.connect_event('button_press_event', self.onrelease)
        self.connect_event('motion_notify_event', self.onmove)
        self._redraw()

    def onrelease(self, event):
        """Finish the interaction (connected to ``button_press_event``)."""
        if self.ignore(event):
            return
        try:
            if self.width and self.callback is not None:
                self.callback(self.line.origin, self.line.vect,
                              self.width)
        finally:
            # Always take the helper patches off the axes, also when the
            # width is still zero, so no invisible patches are left behind.
            for patch in (self.Rleft, self.Rright):
                if patch is not None:
                    patch.remove()
            self.Rleft = None
            self.Rright = None
            self.canvas.draw_idle()
            self.disconnect_events()

    def onmove(self, event):
        """Set the band half-width to the pointer's distance from the line."""
        if self.ignore(event):
            return
        if event.inaxes != self.ax:
            return
        if self.Rleft is None or self.Rright is None:
            return
        if not self.line.mod:
            # Degenerate (zero-length) line: no distance can be defined.
            return

        offset = np.array([event.xdata, event.ydata]) - self.line.origin
        vx, vy = self.line.vect
        # z-component of the 2-D cross product, written out explicitly
        # because ``np.cross`` on 2-D vectors is deprecated.
        dist = abs(vx * offset[1] - vy * offset[0])

        self.width = dist / self.line.mod

        self.Rleft.set_width(self.width)
        self.Rright.set_width(-self.width)
        self._redraw()

    def _redraw(self):
        """Show the current rectangles, blitting when enabled."""
        if self.useblit:
            self.canvas.restore_region(self.background)
            self.ax.draw_artist(self.Rleft)
            self.ax.draw_artist(self.Rright)
            self.canvas.blit(self.ax.bbox)
        else:
            self.canvas.draw_idle()
class RectROI(ROI):
    """Rectangular region of interest drawn and edited with the mouse.

    The rectangle is spanned by the corners ``(x0, y0)`` and ``(x1, y1)``.
    It is created by a left-button press/release pair and, once picked,
    resized with the left button and re-anchored with the right button.
    """

    def __init__(self, ax, fig, canvas, red=0.5, green=0.5, blue=0.5):
        """Initialise an empty ROI; the colour is given as RGB components."""
        ROI.__init__(self, ax, fig, canvas)

        self.x0 = 0
        self.y0 = 0
        self.x1 = 0
        self.y1 = 0
        self.line_color = (red, green, blue)
        self.rect = None

    def button_press_callback(self, event):
        """Remember the first corner on a left press inside an axes."""
        # Only record a corner while no rectangle exists yet.
        if event.inaxes and event.button == 1 and self.rect is None:
            self.x0 = event.xdata
            self.y0 = event.ydata

    def button_release_callback(self, event):
        """Create the rectangle on left release and stop any dragging."""
        if event.button == 1 and self.rect is None:
            ax = event.inaxes
            # A release outside every axes carries no data coordinates and
            # no axes to draw on; skip the creation instead of failing.
            if ax is not None and event.xdata is not None and event.ydata is not None:
                self.x1 = event.xdata
                self.y1 = event.ydata
                width = self.x1 - self.x0
                height = self.y1 - self.y0
                self.rect = Rectangle((self.x0, self.y0), width, height, color=self.line_color,
                                      fill=False, picker=True, visible=True, figure=self.fig)
                ax.add_patch(self.rect)
                self.fig.canvas.draw()

        # Make sure the ROI no longer follows the mouse.
        self.grab_line = None

    def motion_notify_callback(self, event):
        """Resize or move the picked ROI while the mouse is dragged.

        Left button: move the ``(x1, y1)`` corner (changes width/height).
        Right button: move the ``(x0, y0)`` anchor.
        """
        if not event.inaxes:
            return
        if self.rect is None or getattr(self, 'grab_line', None) is None:
            return

        if event.button == 1:
            self.x1 = event.xdata
            self.y1 = event.ydata
            self.rect.set_width(self.x1 - self.x0)
            self.rect.set_height(self.y1 - self.y0)
            self.fig.canvas.draw()
        elif event.button == 3:
            self.x0 = event.xdata
            self.y0 = event.ydata
            self.rect.set_xy((self.x0, self.y0))
            self.fig.canvas.draw()

    def object_picked_callback(self, event):
        """Grab the picked artist if the pick happened on this rectangle."""
        if self.rect is None:
            return
        contains, _details = self.rect.contains(event.mouseevent)
        if contains:
            self.grab_line = event.artist

    def AddLines(self):
        """Add the rectangle (if any) to ``self.ax`` and redraw."""
        if self.rect is not None:
            self.ax.add_patch(self.rect)
            self.ax.figure.canvas.draw()

    def RemoveLines(self):
        """Remove the rectangle from its axes and redraw; no-op without one."""
        if self.rect is None:
            return
        try:
            self.rect.remove()
            self.fig.canvas.draw()
        except AttributeError:
            return

    def GetDimensions(self):
        """Return ``[x0, y0, x1, y1]``."""
        return [self.x0, self.y0, self.x1, self.y1]

    def EditROI(self):
        """Create or update the rectangle from the stored corners.

        Nothing is done while any of the four coordinates is still 0.
        """
        if self.x0 == 0 or self.y0 == 0 or self.x1 == 0 or self.y1 == 0:
            return

        # The rectangle is anchored at (x1, y1) and extends back to (x0, y0).
        width = self.x0 - self.x1
        height = self.y0 - self.y1

        if self.rect is None:
            ax = self.fig.gca()
            self.rect = Rectangle((self.x1, self.y1), width, height, color=self.line_color,
                                  fill=False, picker=True, visible=True, figure=self.fig, axes=ax)
            ax.add_patch(self.rect)
        else:
            self.rect.set_height(height)
            self.rect.set_width(width)
            self.rect.set_xy((self.x1, self.y1))

        self.fig.canvas.draw()

    def SetXY(self, x, y, corner=1):
        """Store ``(x, y)`` as corner 1 (``x0, y0``) or as the other corner."""
        if corner == 1:
            self.x0 = x
            self.y0 = y
        else:
            self.x1 = x
            self.y1 = y
Ejemplo n.º 29
0
class ROIdisplay():
    """Mouse handlers and helpers for drawing regions of interest (ROIs).

    The handlers read ``self._drawmethod`` (0 = circle, 1 = rectangle,
    2 = polygon), draw on ``self._a_img`` and refresh ``self._canvas``.
    Finished ROIs are kept in ``self._roipath`` with their number labels in
    ``self._roilabel``; both lists are created on first use.
    """

    def _roi_artist(self, edge_color):
        """Build the patch spanned by the press point and the current point.

        Returns ``None`` for draw methods other than circle and rectangle.
        """
        if self._drawmethod == 0:
            # Circle centred between the two points, passing through both.
            cx = np.mean([self._roix1, self._roix2])
            cy = np.mean([self._roiy1, self._roiy2])
            radius = np.sqrt((self._roix2 - cx)**2 + (self._roiy2 - cy)**2)
            return Circle((cx, cy), radius, fc='none', ec=edge_color)
        if self._drawmethod == 1:
            width = np.abs(self._roix2 - self._roix1)
            height = np.abs(self._roiy2 - self._roiy1)
            left = np.min([self._roix1, self._roix2])
            bottom = np.min([self._roiy1, self._roiy2])
            return Rectangle((left, bottom), width, height,
                             fc='none', ec=edge_color)
        return None

    def on_mousepress(self, event):
        """Start a new ROI at the pressed position."""
        if event.xdata is None or event.ydata is None:
            # Pressed outside the axes: there is no data position to use.
            return

        self._cpressed = 1

        self._roix1 = event.xdata
        self._roiy1 = event.ydata
        # Until the first drag the second point equals the first one, so a
        # release without an in-axes motion does not use a stale value.
        self._roix2 = event.xdata
        self._roiy2 = event.ydata

        if self._drawmethod == 0:
            self._citem = Circle((self._roix1, self._roiy1),
                                 1,
                                 fc='none',
                                 ec='yellow')
        elif self._drawmethod == 1:
            self._citem = Rectangle((self._roix1, self._roiy1),
                                    1,
                                    1,
                                    fc='none',
                                    ec='yellow')
        self._a_img.add_artist(self._citem)
        self._canvas.draw()

    def on_mousedrag(self, event):
        """Update the preview ROI while the mouse is dragged."""
        if getattr(self, '_cpressed', 0) != 1:
            return
        if event.xdata is None or event.ydata is None:
            # Pointer left the axes: keep the last valid preview.
            return

        self._roix2 = event.xdata
        self._roiy2 = event.ydata

        self._citem.remove()
        preview = self._roi_artist('yellow')
        if preview is not None:
            self._citem = preview
        self._a_img.add_artist(self._citem)
        # One redraw per motion event is enough.
        self._canvas.draw()

    def on_mouseup(self, event):
        """Finalise the ROI, label it and register it in the list box."""
        if event.xdata is None or event.ydata is None:
            # Released outside the axes: fall back to the last known point.
            self._roix2 = getattr(self, '_roix2', event.xdata)
            self._roiy2 = getattr(self, '_roiy2', event.ydata)
        else:
            self._roix2 = event.xdata
            self._roiy2 = event.ydata

        self._citem.remove()

        # Create each bookkeeping list independently so a missing label
        # list can never reset the list of stored ROI artists.
        if not hasattr(self, '_roipath'):
            self._roipath = []
        if not hasattr(self, '_roilabel'):
            self._roilabel = []
        length_roipath = len(self._roipath) + 1

        final = self._roi_artist('orange')
        if final is not None:
            self._citem = final
        self._a_img.add_artist(self._citem)
        roi_label = self._a_img.text(self._roix2,
                                     self._roiy2,
                                     str(length_roipath),
                                     ha="center",
                                     family='sans-serif',
                                     size=14,
                                     color='yellow')

        self._canvas.draw()
        self._cpressed = 0

        self._roipath.append(self._citem)
        self._roilabel.append(roi_label)

        roi_names = {0: 'Circle ', 1: 'Rectangle ', 2: 'Polygon '}
        roi_type = roi_names.get(self._drawmethod, 'ROI ')
        self.roiListbox.insert('end', roi_type + str(length_roipath))

        self.roiselectallButton.state(['!disabled'])
        self.roiclearallButton.state(['!disabled'])
        self.roideleteallButton.state(['!disabled'])
        self.roikeepallButton.state(['!disabled'])

    def draw_selec(self, event):
        """Colour the ROIs selected in the list box green, the rest orange."""
        paths = getattr(self, '_roipath', [])
        for item in paths:
            item.set_ec('orange')
        for index in self.roiListbox.curselection():
            paths[index].set_ec('green')

        self._canvas.draw()

    def show_roi(self):
        """Add all stored ROI artists and labels to the image axes."""
        for item in getattr(self, '_roipath', []):
            self._a_img.add_artist(item)
        for item in getattr(self, '_roilabel', []):
            self._a_img.add_artist(item)
        self._canvas.draw()

    def noshow_roi(self):
        """Reset the ROI colours and remove ROIs and labels from the axes."""
        for item in getattr(self, '_roipath', []):
            item.set_ec('orange')
            item.remove()
        for item in getattr(self, '_roilabel', []):
            item.remove()
        self._canvas.draw()
Ejemplo n.º 30
0
class MPLCanvasWrapper(Gtk.VBox):
    """
    Class that wraps around the Matplotlib Canvas/Figure/Subplot to be used with
    GTK3.
    Supports the following:
            -) Zoom (rectangle, scroll wheel)
            -) Display coordinates
            -) pipes changing the plot settings through to the mpl figure
    """
    __gsignals__ = {
        "update_request":
        (GObject.SIGNAL_RUN_FIRST, GObject.TYPE_NONE, (int, )),
    }

    def append_legend_entries_flag(self, flag=True):
        """Append one entry to the ``show_legend_entries`` flag array.

        ``np.append`` returns a new array, which replaces the attribute.
        """
        current_flags = self.show_legend_entries
        self.show_legend_entries = np.append(current_flags, [flag])

    def reset_legend_entries_flags(self):
        """Replace ``show_legend_entries`` with an empty boolean array."""
        self.show_legend_entries = np.array([], dtype='bool')

    def change_settings(self, *_):
        """
        Callback for the settings button.
        Zooms out, opens a ``SettingsDialog`` pre-filled with the current
        settings and, when the dialog returns response ``1``, applies the
        new settings and emits ``update_request``.
        Returns the new settings, or ``None`` for any other response.
        """
        self.zoom_out(self)

        dialog = SettingsDialog()
        response, settings = dialog.run(parent=self._main_window,
                                        old_settings=self.get_settings())

        if response != 1:
            return None

        self.set_settings(settings)
        self.emit("update_request", self.nbp)
        return settings

    def draw_event(self, *_):
        """
        Event handler for any draw operation on the canvas.
        Used here to set the dummy x-axis to the same scaling as the primary
        x-axis (workaround for coordinate display with mouse-over events).
        Any positional arguments are ignored. Always returns 0.
        """

        # Re-sync the dummy axis with the primary axis.
        self.set_dummy_xlim()

        return 0

    def unhide_secondary_axis(self):
        """
        Make the secondary y-axis visible again.
        The secondary axis is never deleted, only hidden while unused; this
        is the counterpart of ``hide_secondary_axis``.  Nothing is done if
        the axis is already flagged as shown.  Always returns 0.
        """
        if not self.show_secondary_axis:
            self.show_secondary_axis = True
            self.secondary_axis.yaxis.set_visible(True)
            self.draw_idle()

        return 0

    def hide_secondary_axis(self):
        """
        Hide the secondary y-axis.
        The secondary axis cannot be deleted without remaking the whole
        canvas, so it is only hidden while it is not used.  Nothing is done
        if the axis is already flagged as hidden.  Always returns 0.
        """
        if self.show_secondary_axis:
            self.show_secondary_axis = False
            self.secondary_axis.yaxis.set_visible(False)
            self.draw_idle()

        return 0

    def zoom_out(self, *_):
        """
        Leave zoom mode: restore the limits saved in ``maxlimits`` and the
        autoscale state saved in ``autoscale_old``, then redraw.
        Does nothing when no zoom is active.  Always returns 0.
        """
        if not self.zoom_flag:
            return 0

        self.zoom_flag = False

        self.axis.set_xlim(self.maxlimits[:2])
        self.axis.set_ylim(self.maxlimits[2:])

        saved_x, saved_y = self.autoscale_old[0], self.autoscale_old[1]
        self.set_autoscale(saved_x, "x")
        self.set_autoscale(saved_y, "y")

        if self.xtime:
            self.set_xtime()

        self.canvas.draw_idle()
        return 0

    def draw_zoom_box(self, *_):
        """
        Draw the zoom-to-rectangle rubberband between ``box_start`` and
        ``box_end``.  Any previous rubberband is removed first.  Nothing is
        drawn while either corner is unset.  Always returns 0.
        """
        start, end = self.box_start, self.box_end
        if start is None or end is None:
            return 0

        if self.zoom_box is not None:
            self.zoom_box.remove()

        lower_left = (min(start[0], end[0]), min(start[1], end[1]))
        box_width = abs(start[0] - end[0])
        box_height = abs(start[1] - end[1])

        self.zoom_box = Rectangle(xy=lower_left,
                                  width=box_width,
                                  height=box_height,
                                  linestyle='dashed',
                                  linewidth=1,
                                  color='grey',
                                  fill=False)
        self.axis.add_artist(self.zoom_box)
        self.canvas.draw_idle()

        return 0

    def mouse_move(self, event):
        """
        Track the pointer: update the coordinate labels, extend the zoom
        rubberband while the button is held down, and remember whether the
        pointer is inside an axes.  Always returns 0.
        """
        if not event.inaxes:
            self.inaxes = False
            return 0

        x, y = event.xdata, event.ydata

        self.z_label.set_text("| %.4e" % x)
        self.r_label.set_text("%.4e | " % y)

        if self.holddown:
            start_px = self.box_start_px
            # The box is only updated once the pointer has moved in both
            # pixel directions relative to the press position.
            if start_px[0] != event.x and start_px[1] != event.y:
                self.box_end = x, y
                self.draw_zoom_box(self)

        self.inaxes = True
        self.mouse_x = x
        self.mouse_y = y

        return 0

    # TODO: Unused?
    @staticmethod
    def axes_enter_callback(event):
        """Placeholder handler: reads the event's data position, returns 0."""
        _position = (event.xdata, event.ydata)

        return 0

    @staticmethod
    def axes_leave_callback(event):
        """Placeholder handler: reads the event's data position, returns 0."""
        _position = (event.xdata, event.ydata)

        return 0

    def button_pressed(self, event):
        """
        Handle a mouse button press on the canvas.

        Returns 1 when a left or right click lands inside the legend's
        window extent (the click is then left alone), otherwise 0.
        Right click elsewhere zooms out; left click inside an axes starts a
        zoom rubberband.  The middle button does nothing.
        """
        leg = self.get_legend()

        def click_on_legend():
            # Pixel-space hit test against the legend's bounding box.
            if leg is None:
                return False
            start, end = leg.get_window_extent().get_points()
            x_px, y_px = event.x, event.y
            return start[0] <= x_px <= end[0] and start[1] <= y_px <= end[1]

        if event.button == 3:
            # Action on pressing right mouse button inside canvas
            if click_on_legend():
                return 1
            self.zoom_out(self)

        elif event.button == 1:
            # Action on pressing left mouse button inside canvas
            if click_on_legend():
                return 1

            self.box_start = None

            if event.inaxes:
                self.holddown = True
                self.box_start = event.xdata, event.ydata
                self.box_start_px = event.x, event.y

        # Pressing the mouse wheel button (button 2) has no action.
        return 0

    def zoom(self, *_):
        """
        Zoom the primary axis to the box spanned by ``box_start`` and
        ``box_end``.

        On the first zoom the current limits and autoscale state are saved
        (``maxlimits``, ``autoscale_old``) so ``zoom_out`` can restore them.
        If either box corner is unset nothing happens; ``box_end`` stays
        ``None`` when no pointer motion inside the axes was registered
        during the drag.  Always returns 0.
        """
        start, end = self.box_start, self.box_end
        if start is None or end is None:
            # Incomplete rubberband: do not enter zoom mode at all.
            return 0

        if not self.zoom_flag:
            self.maxlimits = self.get_limits()[:-2]
            self.zoom_flag = True
            self.autoscale_old = [
                self.get_autoscalex(),
                self.get_autoscaley1()
            ]

        self.zoomlimits = [
            min(start[0], end[0]),
            max(start[0], end[0]),
            min(start[1], end[1]),
            max(start[1], end[1])
        ]

        self.axis.set_xlim(self.zoomlimits[:2])
        self.axis.set_ylim(self.zoomlimits[2:])

        if self.xtime:
            self.set_xtime()

        self.canvas.draw_idle()

        return 0

    def button_released(self, event):
        """
        Handle a mouse button release: finish a left-button rubberband.

        The rubberband artist is removed, the zoom is applied if the pointer
        moved in both pixel directions since the press, and all rubberband
        state is reset.  Always returns 0.
        """
        if event.button != 1 or not self.holddown:
            return 0

        if self.zoom_box is not None:
            self.zoom_box.remove()

        start_px = self.box_start_px
        if start_px[0] != event.x and start_px[1] != event.y:
            self.zoom(self)

        self.holddown = False
        self.box_start = None
        self.box_start_px = None
        self.box_end = None
        self.zoom_box = None

        return 0

    def mouse_scroll(self, event):
        """
        Zoom in or out by 5 % around the pointer position.

        A negative ``event.step`` zooms out, a positive one zooms in; a
        zero step or an event outside the axes is ignored.  On the first
        zoom the current limits and autoscale state are saved for
        ``zoom_out``.  Always returns 0.

        Cave: Something wrong with the aspect ratio!
        """
        zoom_perc = 0.05

        if not event.inaxes:
            return 0

        if event.step == 0:
            # No direction: without this guard the raw 0.05 would be used
            # as the scale factor and shrink the view to 5 %.
            return 0

        x = event.xdata
        y = event.ydata

        lim = self.get_limits()

        # first zoom - set maxlimits
        if not self.zoom_flag:
            self.maxlimits = lim[:-2]
            self.zoom_flag = True
            self.autoscale_old = [
                self.get_autoscalex(),
                self.get_autoscaley1()
            ]

        # Scale factor applied to the distance between pointer and limits.
        if event.step < 0:
            scale = 1 + zoom_perc  # zoom out
        else:
            scale = 1 - zoom_perc  # zoom in

        xmax = x + scale * (lim[1] - x)
        xmin = x + scale * (lim[0] - x)
        ymax = y + scale * (lim[3] - y)
        ymin = y + scale * (lim[2] - y)

        self.axis.set_xlim([xmin, xmax])
        self.axis.set_ylim([ymin, ymax])

        if self.xtime:
            self.set_xtime()

        self.canvas.draw_idle()

        return 0

    def plot(self, *args, **kwargs):
        """
        Forward to ``plot`` of the primary or the secondary axis.

        Extra keyword arguments consumed here:
        show_in_legend...legend flag recorded for this plot (default True)
        secondary_axis...if true, plot on the (unhidden) secondary axis
        """
        self.append_legend_entries_flag(kwargs.pop("show_in_legend", True))

        # If there are no previous plots, get the line width and marker size
        # from the saved main_plot_settings
        existing, _labels = self.axis.get_legend_handles_labels()

        if not existing:
            adjustments = self.main_plot_settings["adj"]
            kwargs["linewidth"] = adjustments["line_width"]
            kwargs["markersize"] = adjustments["marker_size"]

        # Handle secondary axis
        target = self.axis
        if kwargs.pop("secondary_axis", False):
            self.unhide_secondary_axis()
            target = self.secondary_axis

        return target.plot(*args, **kwargs)

    def scatter(self, *args, **kwargs):
        """
        Forward to ``scatter`` of the primary or the secondary axis.

        Extra keyword arguments consumed here:
        show_in_legend...legend flag recorded for this plot (default True)
        secondary_axis...if true, plot on the (unhidden) secondary axis
        """
        self.append_legend_entries_flag(kwargs.pop("show_in_legend", True))

        # If there are no previous plots, get the marker size from saved
        # main_plot_settings
        existing, _labels = self.axis.get_legend_handles_labels()

        if not existing:
            kwargs["s"] = self.main_plot_settings["adj"]["marker_size"]

        target = self.axis
        if kwargs.pop("secondary_axis", False):
            self.unhide_secondary_axis()
            target = self.secondary_axis

        return target.scatter(*args, **kwargs)

    def toggle_legend(self):
        """
        Flip the legend's visibility, creating the legend (shown) if none
        exists yet, then request a redraw.  Always returns 0.
        """
        leg = self.get_legend()

        if leg is None:
            self.set_legend(flag=True)
        else:
            leg.set_visible(not leg.get_visible())

        self.draw_idle()

        return 0

    def reset_secondary_axes(self):
        """
        Delete and recreate the secondary y-axis and the dummy x-axis.
        The figure's axes list is printed before and after the rebuild.
        Always returns 0.
        """
        fig = self.figure
        print(fig.axes)

        for stale_axes in (self.secondary_axis, self.dummy_axis):
            fig.delaxes(stale_axes)

        # --- The secondary y-axis (hidden by default).  The flag is set to
        # True first so that hide_secondary_axis() actually hides it.
        twin_y = self.axis.twinx()
        self.secondary_axis = twin_y
        twin_y.set_frame_on(False)
        self.show_secondary_axis = True
        self.hide_secondary_axis()

        # --- A dummy secondary x axis for mousover events
        twin_x = self.axis.twiny()
        self.dummy_axis = twin_x
        twin_x.set_frame_on(False)
        twin_x.set_xticks([])
        self.set_dummy_xlim()

        print(self.figure.axes)

        return 0

    def set_settings(self, settings=None):
        """
        Apply the plot settings stored in the ``settings`` mapping.

        ``settings["entry"]`` holds texts and axis limits, ``settings["adj"]``
        holds numeric adjustments (widths, sizes, font sizes).  Nothing is
        done when ``settings`` is None.  Always returns 0.
        """
        if settings is None:
            return 0

        # Get the containers for the different values
        ent = settings["entry"]
        adj = settings["adj"]
        self.main_plot_settings = settings

        # Set the plots' settings (line width, colors etc.)
        self.set_linewidth(adj["line_width"])
        self.set_markersize(adj["marker_size"])

        # Set the Canvas (axes, fontsizes, legend, etc.) settings
        label_fs = adj["label_fs"]
        self.set_title(ent["title"], adj["title_fs"])
        self.set_xlabel(ent["x_label"], label_fs)
        self.set_ylabel(ent["y1_label"], label_fs)
        self.set_secondary_ylabel(ent["y2_label"], label_fs)

        self.set_ticksize(adj["major_ticks"])

        # Set the plot labels (this is where the legend entries are stored)
        primary_handles, _ = self.axis.get_legend_handles_labels()
        secondary_handles, _ = self.secondary_axis.get_legend_handles_labels()
        handles = primary_handles + secondary_handles

        if "legend_entries_flags" in settings:
            self.show_legend_entries = settings["legend_entries_flags"]

        if "legend_entries" in settings:
            legend_entries = settings["legend_entries"]

            for handle, label in zip(handles, legend_entries):
                handle.set_label(label)

            self.set_legend(settings["legend_global_flag"],
                            legend_entries, adj["legend_fs"],
                            adj["legend_ms"], adj["legend_lw"])

        # Each axis either autoscales or takes the explicit limits.
        if settings["autoscalex"]:
            self.set_autoscale(True, "x")
        else:
            self.set_xlim(ent["x_min"], ent["x_max"])

        if settings["autoscaley1"]:
            self.set_autoscale(True, "y")
        else:
            self.set_ylim(ent["y1_min"], ent["y1_max"])

        if settings["autoscaley2"]:
            self.set_secondary_autoscale(True, "y")
        else:
            self.set_secondary_ylim(ent["y2_min"], ent["y2_max"])

        if self.xtime:
            self.set_xtime()

        if not self.show_secondary_axis:
            self.hide_secondary_axis()

        self.canvas.draw_idle()

        return 0

    def set_legend(self, flag=False, entries=None, fs=18, ms=1, lw=1):
        """
        Handles the creation and destruction of the legend object.

        The legend is built on the dummy axis from the handles of both the
        primary and the secondary axis, filtered with the boolean mask stored
        in self.show_legend_entries.

        flag......whether or not the legend is to be displayed
        entries...unused (the labels are read from the plotted artists); the
                  parameter is only kept so existing callers keep working
        fs........the legend entries' font size
        ms........the legend marker scale
        lw........the legend linewidth

        Always returns 0.
        """
        del entries

        handles1, labels1 = self.axis.get_legend_handles_labels()
        handles2, labels2 = self.secondary_axis.get_legend_handles_labels()

        # Arrays (instead of lists) so the boolean mask can select entries
        handles = np.array([handles1 + handles2])[0]
        labels = np.array([labels1 + labels2])[0]

        handles = handles[self.show_legend_entries]
        labels = labels[self.show_legend_entries]

        prop = FontProperties(size=fs)

        self.figure.sca(self.axis)

        leg = self.dummy_axis.legend(handles,
                                     labels,
                                     loc=3,
                                     markerscale=ms,
                                     prop=prop,
                                     scatterpoints=1,
                                     numpoints=1)

        # Newer Matplotlib releases only offer Legend.set_draggable(); older
        # ones only Legend.draggable(). Use whichever this version provides.
        if hasattr(leg, "set_draggable"):
            leg.set_draggable(True)
        else:
            leg.draggable()

        # Same story for the handle list: "legend_handles" is the current
        # name, "legendHandles" the historical one.
        legend_handles = getattr(leg, "legend_handles", None)
        if legend_handles is None:
            legend_handles = leg.legendHandles

        for entry in legend_handles:
            entry.set_linewidth(lw)

        # The legend object always exists; it is merely hidden when it is
        # switched off or has nothing to show.
        if not (flag and len(labels) > 0):
            leg.set_visible(False)

        return 0

    def set_ticksize(self, ticksize=20):
        """
        Apply one font size to the major tick labels of all axes (and to the
        colorbar ticks, if a colorbar exists). Always returns 0.
        """
        # Primary x and y axis use the ticks' first label
        for axis_obj in (self.axis.xaxis, self.axis.yaxis):
            for major in axis_obj.get_major_ticks():
                major.label1.set_fontsize(ticksize)

        # The secondary y axis uses the ticks' second label
        for major in self.secondary_axis.yaxis.get_major_ticks():
            major.label2.set_fontsize(ticksize)

        colorbar = self.cb
        if colorbar is not None:
            colorbar.ax.tick_params(labelsize=ticksize)

        return 0

    def set_title(self, title="Title", fontsize=None):
        """
        Set the title of the primary axis. Without an explicit fontsize the
        current title size is reused. Returns the result of the axis'
        set_title() call.
        """
        size = self.get_titlesize() if fontsize is None else fontsize

        return self.axis.set_title(title, fontsize=size)

    def set_xlabel(self, label="x", fontsize=None):
        """
        Set the x axis label of the primary axis. Without an explicit
        fontsize the current x label size is reused. Always returns 0.
        """
        size = self.get_xaxis_labelsize() if fontsize is None else fontsize
        self.axis.set_xlabel(label, fontsize=size)

        return 0

    def set_ylabel(self, label="y", fontsize=None):
        """
        Set the y axis label of the primary axis. Without an explicit
        fontsize the current y label size is reused. Always returns 0.
        """
        size = self.get_yaxis_labelsize() if fontsize is None else fontsize
        self.axis.set_ylabel(label, fontsize=size)

        return 0

    def set_yscale(self, scale):
        """
        Set the scale of the primary y axis.

        scale is handed to the axis unchanged; it can be 'linear', 'log' or
        'symlog'. Always returns 0.
        """
        primary = self.axis
        primary.set_yscale(scale)

        return 0

    def set_secondary_ylabel(self, label="secondary y", fontsize=None):
        """
        Set the y axis label of the secondary axis. Without an explicit
        fontsize the label size of the primary y axis is used.
        Always returns 0.
        """
        size = self.get_yaxis_labelsize() if fontsize is None else fontsize
        self.secondary_axis.set_ylabel(label, fontsize=size)

        return 0

    def set_xlim(self, xmin, xmax):
        """
        Set the x limits of the primary axis and bring the dummy axis in
        line with them. Always returns 0.
        """
        primary = self.axis
        primary.set_xlim(xmin, xmax)

        # The dummy x axis has to follow the primary one
        self.set_dummy_xlim()

        return 0

    def set_ylim(self, ymin, ymax):
        """
        Set the y limits of the primary axis. Always returns 0.
        """
        primary = self.axis
        primary.set_ylim(ymin, ymax)

        return 0

    def set_dummy_xlim(self):
        """
        Copy the current x limits (the first two entries of get_limits())
        onto the dummy axis. Always returns 0.
        """
        current = self.get_limits()
        lower, upper = current[0], current[1]
        self.dummy_axis.set_xlim(lower, upper)

        return 0

    def set_secondary_ylim(self, ymin, ymax):
        """
        Set the y limits of the secondary axis. Always returns 0.
        """
        secondary = self.secondary_axis
        secondary.set_ylim(ymin, ymax)

        return 0

    def set_aspect(self, aspect='equal'):
        """
        Set the aspect of the primary axis only; the secondary and the dummy
        axis are left untouched. Always returns 0.
        """
        primary = self.axis
        primary.set_aspect(aspect)

        return 0

    def set_linewidth(self, lw=1):
        """
        Apply the line width lw to every artist that has a legend handle on
        the primary or the secondary axis. Always returns 0.
        """
        primary_artists, _ = self.axis.get_legend_handles_labels()
        secondary_artists, _ = self.secondary_axis.get_legend_handles_labels()

        every_artist = primary_artists + secondary_artists
        for artist in every_artist:
            artist.set_lw(lw)

        return 0

    def set_scatter_markersize(self, artist, new_ms, axis=1):
        """
        Change the marker size of a scatter plot by replacing it.

        The artist's colors, label and point positions are read out, the
        artist is removed and a new scatter plot with size new_ms is drawn on
        the primary (axis=1) or secondary (axis=2) axis. For any other axis
        value the artist is removed and nothing is drawn in its place.
        Always returns 0.
        """
        prop = ArtistInspector(artist).properties()

        edgecolors = prop["edgecolors"]
        edgecolor = edgecolors[0] if len(edgecolors) > 0 else []
        facecolor = prop["facecolors"][0]
        label = prop["label"]

        # One (x, y) pair per row
        offsets = np.array(prop["offsets"])
        x, y = offsets[:, 0], offsets[:, 1]

        artist.remove()

        if axis == 1:
            target = self.axis
        elif axis == 2:
            target = self.secondary_axis
        else:
            return 0

        target.scatter(x,
                       y,
                       s=new_ms,
                       c=facecolor,
                       edgecolor=edgecolor,
                       label=label)

        return 0

    def set_markersize(self, ms=18):
        """
        Apply the marker size ms to every artist that has a legend handle.

        Line2D artists are resized in place; everything else is treated as a
        scatter plot and handed to set_scatter_markersize() together with
        the number of the axis it lives on. Always returns 0.
        """
        primary_artists, _ = self.axis.get_legend_handles_labels()
        secondary_artists, _ = self.secondary_axis.get_legend_handles_labels()

        for axis_id, artists in ((1, primary_artists), (2, secondary_artists)):
            for artist in artists:
                if type(artist) == matplotlib.lines.Line2D:
                    artist.set_ms(ms)
                    continue

                self.set_scatter_markersize(artist, ms, axis=axis_id)

        return 0

    def set_autoscale(self, autoscale=True, axis="both", tight=None):
        """
        Forward the autoscale request to the primary axis. Always returns 0.
        """
        primary = self.axis
        primary.autoscale(autoscale, axis, tight)

        return 0

    def set_secondary_autoscale(self, autoscale=True, axis="y", tight=None):
        """
        Forward the autoscale request to the secondary axis.
        Always returns 0.
        """
        secondary = self.secondary_axis
        secondary.autoscale(autoscale, axis, tight)

        return 0

    def set_xtime(self):
        """
        Assumes the x axis values are given in s and relabels the x axis
        with 5 evenly spaced tickmarks whose labels are produced by
        self.timestr() (hh:mm:ss format) rather than shown in s.
        Always returns 0.
        """
        # Only the x range is needed; the y limits are ignored
        xmin, xmax, _, _, _, _ = self.get_limits()

        tick_positions = np.linspace(xmin, xmax, 5)
        tick_labels = [self.timestr(position) for position in tick_positions]

        self.axis.set_xticks(tick_positions)
        self.axis.set_xticklabels(tick_labels)

        return 0

    def get_autoscalex(self):
        """
        Return whether autoscaling is switched on for the x axis of the
        primary axis.
        """
        primary = self.axis
        return primary.get_autoscalex_on()

    def get_autoscaley1(self):
        """
        Return whether autoscaling is switched on for the y axis of the
        primary axis.
        """
        primary = self.axis
        return primary.get_autoscaley_on()

    def get_autoscaley2(self):
        """
        Return whether autoscaling is switched on for the y axis of the
        secondary axis.
        """
        secondary = self.secondary_axis
        return secondary.get_autoscaley_on()

    def get_scale(self):
        """
        Return the scale of the primary y axis as reported by the axis.
        """
        primary_y = self.axis.yaxis
        return primary_y.get_scale()

    def get_settings(self):
        """
        Gets the settings from the figure object and returns them as one
        settings dictionary with the keys "entry" (title, labels, limits),
        "adj" (font sizes, line width, marker size), the legend information,
        the autoscale flags and the xtime flag.
        """
        x_min, x_max, y1_min, y1_max, y2_min, y2_max = self.get_limits()
        x_label, y1_label, y2_label = self.get_axis_labels()

        # Without a legend object the legend counts as switched off
        legend = self.get_legend()
        legend_global_flag = (legend.get_visible()
                              if legend is not None else False)

        _, primary_labels = self.axis.get_legend_handles_labels()
        _, secondary_labels = self.secondary_axis.get_legend_handles_labels()

        return {
            "entry": {
                "title": self.get_title(),
                "x_label": x_label,
                "x_min": x_min,
                "x_max": x_max,
                "y1_label": y1_label,
                "y1_min": y1_min,
                "y1_max": y1_max,
                "y2_label": y2_label,
                "y2_min": y2_min,
                "y2_max": y2_max,
            },
            "adj": {
                "major_ticks": self.get_xaxis_ticksize(),
                "label_fs": self.get_xaxis_labelsize(),
                "legend_ms": self.get_legend_ms(),
                "legend_fs": self.get_legend_fs(),
                "legend_lw": self.get_legend_lw(),
                "title_fs": self.get_titlesize(),
                "line_width": self.get_linewidth(),
                "marker_size": self.get_markersize(),
            },
            "legend_entries": primary_labels + secondary_labels,
            "legend_global_flag": legend_global_flag,
            "legend_entries_flags": self.show_legend_entries,
            "autoscalex": self.get_autoscalex(),
            "autoscaley1": self.get_autoscaley1(),
            "autoscaley2": self.get_autoscaley2(),
            "xtime": self.xtime,
        }

    def get_limits(self):
        """
        returns the limits as a list in the format
        [xmin, xmax, ymin, ymax, secondary ymin, secondary ymax]
        """
        x_range = list(self.axis.get_xlim())
        y_range = list(self.axis.get_ylim())
        secondary_y_range = list(self.secondary_axis.get_ylim())

        return x_range + y_range + secondary_y_range

    def get_title(self):
        """
        Returns the current title of the primary axis
        """
        primary = self.axis
        return primary.get_title()

    def get_titlesize(self):
        """
        Returns the current font size of the primary axis' title
        """
        title_text = self.axis.title
        return title_text.get_size()

    def get_axis_labels(self):
        """
        Returns the current axis labels as a tuple
        (x, primary y, secondary y)
        """
        x_text = self.axis.xaxis.get_label_text()
        y1_text = self.axis.yaxis.get_label_text()
        y2_text = self.secondary_axis.yaxis.get_label_text()

        return x_text, y1_text, y2_text

    def get_xaxis_labelsize(self):
        """
        Returns the current font size of the primary axis' x label
        """
        x_label = self.axis.xaxis.get_label()
        return x_label.get_size()

    def get_yaxis_labelsize(self):
        """
        Returns the current font size of the primary axis' y label
        """
        y_label = self.axis.yaxis.get_label()
        return y_label.get_size()

    def get_xaxis_ticksize(self):
        """
        Get the major tick label font size of the x axis, read from the
        first major tick
        """
        first_tick = self.axis.xaxis.get_major_ticks()[0]
        return first_tick.label1.get_fontsize()

    def get_yaxis_ticksize(self):
        """
        Get the major tick label font size of the primary y axis, read from
        the first major tick
        """
        first_tick = self.axis.yaxis.get_major_ticks()[0]
        return first_tick.label1.get_fontsize()

    def get_legend(self):
        """
        Return the legend attached to the dummy axis, exactly as the axis
        reports it.
        """
        legend_host = self.dummy_axis
        return legend_host.get_legend()

    def get_legend_ms(self):
        """
        Return the marker scale of the legend.

        Only scatter plots (PathCollection legend handles) are considered:
        the scale is the square root of the handle's first size divided by
        the current marker size. If several scatter handles exist, the last
        one wins. Without a legend, without legend texts or without a
        scatter handle, 1.0 is returned.
        """
        legend = self.get_legend()

        ms = 1.0

        if legend is None or len(legend.get_texts()) == 0:
            return ms

        # "legend_handles" is the current Matplotlib name of the handle
        # list, "legendHandles" the historical one; support both.
        handles = getattr(legend, "legend_handles", None)
        if handles is None:
            handles = legend.legendHandles

        for handle in handles:

            if type(handle) == PathCollection:
                ms = np.sqrt(handle.get_sizes()[0] / self.get_markersize())

        return ms

    def get_legend_lw(self):
        """
        Return the line width used by the legend.

        Only Line2D legend handles are considered; if several exist, the
        last one wins. Without a legend, without legend texts or without a
        Line2D handle, 1.0 is returned.
        """
        legend = self.get_legend()

        lw = 1.0

        if legend is None or len(legend.get_texts()) == 0:
            return lw

        # "legend_handles" is the current Matplotlib name of the handle
        # list, "legendHandles" the historical one; support both.
        handles = getattr(legend, "legend_handles", None)
        if handles is None:
            handles = legend.legendHandles

        for handle in handles:

            if type(handle) == matplotlib.lines.Line2D:
                lw = handle.get_linewidth()

        return lw

    def get_legend_fs(self):
        """
        Return the font size of the first legend text, or 18 when there is
        no legend or the legend has no texts.
        """
        legend = self.get_legend()

        if legend is None:
            return 18

        texts = legend.get_texts()
        if len(texts) > 0:
            return texts[0].get_size()

        return 18

    def get_markersize(self):
        """
        Return the marker size of the first plot on the primary axis.

        Without any plot the value stored in self.main_plot_settings is
        returned. A Line2D reports its marker size directly; any other
        artist is asked for its first entry of get_sizes().
        """
        plots, _ = self.axis.get_legend_handles_labels()

        if not plots:
            return self.main_plot_settings["adj"]["marker_size"]

        first = plots[0]

        if type(first) == matplotlib.lines.Line2D:
            return first.get_ms()

        return first.get_sizes()[0]

    def get_linewidth(self):
        """
        Return the line width of the first plot on the primary axis.

        Without any plot the value stored in self.main_plot_settings is
        returned. A Line2D reports a single width; any other artist is
        expected to report a sequence, of which the first entry is used.
        """
        plots, _ = self.axis.get_legend_handles_labels()

        if not plots:
            return self.main_plot_settings["adj"]["line_width"]

        first = plots[0]

        if type(first) == matplotlib.lines.Line2D:
            return first.get_linewidth()

        return first.get_linewidth()[0]

    def draw_idle(self):
        """
        Request a redraw of the canvas and, if the x axis is in time format,
        refresh the time tick labels afterwards. Always returns 0.
        """
        canvas = self.canvas
        canvas.draw_idle()

        if self.xtime:
            self.set_xtime()

        return 0

    def clear(self):
        """
        Clear the primary and the secondary axis and hide the secondary
        axis again. Always returns 0.
        """
        for axes in (self.axis, self.secondary_axis):
            axes.clear()

        self.hide_secondary_axis()

        return 0

    def __init__(self, main_window=None, nbp=0):
        """
        Build the plot widget: a Matplotlib figure on a GTK3 canvas with a
        primary axis, a secondary y axis and a dummy x axis, plus a button /
        label bar below the canvas.

        main_window...stored as self._main_window
        nbp...........stored as self.nbp
        """

        Gtk.VBox.__init__(self)

        self._main_window = main_window

        self.figure = Figure(facecolor='#474747', edgecolor='#474747')
        self.canvas = FigureCanvasGTK3Agg(self.figure)  # a Gtk.DrawingArea

        # --- Flag for primary axis time format
        self.xtime = False

        # --- The primary axis
        self.axis = self.figure.add_subplot(111)

        # --- The secondary y-axis (hidden by default)
        self.secondary_axis = self.axis.twinx()
        self.secondary_axis.set_frame_on(False)
        self.show_secondary_axis = True
        self.hide_secondary_axis()

        # --- A dummy secondary x axis for mousover events
        self.dummy_axis = self.axis.twiny()
        self.dummy_axis.set_frame_on(False)
        self.dummy_axis.set_xticks([])
        self.set_dummy_xlim()

        # Colorbar handle; None means there is no colorbar (see set_ticksize)
        self.cb = None
        self.nbp = nbp

        # Init some globals for getting the mouse position in the client
        # application
        self.inaxes = False
        self.mouse_x = 0
        self.mouse_y = 0

        # --- Bar below the canvas: settings button, reset button, labels
        hbox = Gtk.HBox(spacing=4)

        self.r_label = Gtk.Label("0.0 | ")
        self.z_label = Gtk.Label("| 0.0")
        self.spacer = Gtk.Label("|")
        self.reset_b = Gtk.Button("Reset Zoom")
        self.settings_b = Gtk.Button("Settings")

        # The canvas takes all spare space, the bar only what it needs
        self.pack_start(self.canvas, True, True, 0.0)
        self.pack_start(hbox, False, False, 0.0)
        hbox.pack_end(self.r_label, False, False, 0.0)
        hbox.pack_end(self.spacer, False, False, 0.0)
        hbox.pack_end(self.z_label, False, False, 0.0)
        hbox.pack_end(self.reset_b, False, False, 0.0)
        hbox.pack_start(self.settings_b, False, False, 0.0)

        # --- Variables for zooming --- #
        self.holddown = False
        self.zoom_flag = False
        self.box_start = None
        self.box_start_px = None
        self.box_end = None
        self.zoom_box = None
        self.zoomlimits = [0., 1., 0., 1.]
        self.maxlimits = []
        self.autoscale_old = [True, True]

        # --- Container for flags to show certain legend entries -- #
        self.show_legend_entries = np.array([], 'bool')

        # --- Connections --- #
        self.canvas.mpl_connect("motion_notify_event", self.mouse_move)
        self.canvas.mpl_connect("scroll_event", self.mouse_scroll)
        self.canvas.mpl_connect("button_press_event", self.button_pressed)
        self.canvas.mpl_connect("button_release_event", self.button_released)
        self.reset_b.connect("clicked", self.zoom_out)
        self.settings_b.connect("clicked", self.change_settings)
        self.canvas.mpl_connect("axes_enter_event", self.axes_enter_callback)
        self.canvas.mpl_connect("axes_leave_event", self.axes_leave_callback)
        self.canvas.mpl_connect("draw_event", self.draw_event)

        # Apply the defaults; this has to happen after all axes exist
        self.set_settings(generate_default_settings())

        # Create a plotsettings dictionary as a copy of the current plot settings
        # Redundand, but a quick workaround for keeping the marker size and linewidth
        # even if settings were loaded without any plots
        self.main_plot_settings = self.get_settings()
Ejemplo n.º 31
0
class AC_Yregion(Analysis_Cursor):
    """
    Analysis cursor marking a horizontal band between two y values.

    The band is drawn as a hatched rectangle that spans the x range reported
    by the plot frame; its lower and upper edge (y1, y2) can be dragged.
    """
    Type = 'Y-Region'
    Prefix = 'Y'

    def __init__(self, name, colour='black'):
        """Create the cursor with a default band from 0 to 0.1."""
        super().__init__(name)
        self.y1 = 0
        self.y2 = 0.1
        self._fill = '/'
        self.colour = colour
        # The rectangle is only created in prepare_plot(); defining it here
        # keeps delete_from_plot() safe when it runs before that.
        self.rect = None

    @property
    def Summary(self):
        """Short text showing the two y values of the band."""
        return f'[{self.y1,self.y2}]'

    @property
    def SymbolFill(self):
        """Hatch pattern handed to the rectangle on creation."""
        return self._fill

    @SymbolFill.setter
    def SymbolFill(self, new_hatch):
        self._fill = new_hatch

    def prepare_plot(self, pltfrm, ax):
        """Create the rectangle patch and add it to the axes."""
        super().prepare_plot(pltfrm, ax)
        # Position and size are placeholders; render_blit() sets the real
        # geometry from the current limits.
        self.rect = Rectangle((0.0, self.y1),
                              0.1,
                              0.1,
                              angle=0.0,
                              facecolor='none',
                              edgecolor=self.colour,
                              hatch=self._fill)
        self.ax.add_patch(self.rect)

    def reset_cursor(self):
        """
        Pull both edges back into the y limits reported by the plot frame
        and, if the band has collapsed to less than 0.5 % of the y range,
        move y2 halfway towards the limit that is further away from y1.
        """
        lims = self.pltfrm.get_data_limits()
        y_low, y_high = lims[2], lims[3]

        self.y1 = min(max(self.y1, y_low), y_high)
        self.y2 = min(max(self.y2, y_low), y_high)

        y_span = y_high - y_low
        if y_span == 0:
            # Degenerate limits: there is no room to open the band and the
            # relative width below would divide by zero.
            return

        if abs(self.y1 - self.y2) / y_span < 0.005:
            if abs(y_low - self.y1) > abs(y_high - self.y1):
                self.y2 = (y_low + self.y2) * 0.5
            else:
                self.y2 = (y_high + self.y2) * 0.5

    def delete_from_plot(self):
        """Remove the rectangle from the plot, if one has been created."""
        if self.rect:
            self.rect.remove()

    def render_blit(self):
        """Update the rectangle's geometry and draw it onto the axes."""
        if self.ax:
            lims = self.pltfrm.get_data_limits()
            self.rect.set_xy((lims[0], self.y1))
            self.rect.set_height(self.y2 - self.y1)
            self.rect.set_width(lims[1] - lims[0])
            self.ax.draw_artist(self.rect)

    def event_drag(self, coord):
        """Move whichever edge is being dragged to the y value of coord."""
        if self._is_drag == 'y1':
            self.y1 = coord[1]
        elif self._is_drag == 'y2':
            self.y2 = coord[1]

    def _event_mouse_pressed(self, mouse_coord):
        """
        Decide which edge (if any) a mouse press grabs: the first edge whose
        vertical pixel distance to the pointer is below the drag threshold,
        with y1 taking precedence over y2.
        """
        x_pos = mouse_coord[0]
        if self._pixel_distance(mouse_coord,
                                (x_pos, self.y1))[1] < self._drag_threshold:
            self._is_drag = 'y1'
        elif self._pixel_distance(mouse_coord,
                                  (x_pos, self.y2))[1] < self._drag_threshold:
            self._is_drag = 'y2'
        else:
            self._is_drag = 'None'
Ejemplo n.º 32
0
class identificationWidget(QMainWindow):
    def __init__(self):
        """
        Set up the default peak-finding parameters, the (still empty)
        spectrum / matching state and the re-identification child widget,
        then build the user interface.
        """
        super().__init__()

        # Peak-finding and picking parameters (initial slider positions)
        self.peakThreshold = 0.01
        self.peakNumber = 50
        self.peakDistance = 2
        self.pickDistance = 1
        self.selfFWHM = 2
        self.REIDYStep = 2

        # Interaction state
        self.isPressed = False
        self.isPicked = False
        self.isMatchFinished = True
        self.currentPickedPeakWavelength = 0

        # Spectra, peaks and image data; filled once files are loaded
        self.selfPeakPixs = []
        self.standardPeakWavelengths = []
        self.matchedPeakPixs = []
        self.standardSpectrum = []
        self.selfSpectrum = []
        self.selfImageY = [0, 0]
        self.selfData = []

        # Wavelength <-> pixel match table, starting with one blank row
        self.wavelengthPixelList = pd.DataFrame({
            'Wavelength': [' '],
            'Pixel': [' '],
        })

        self.reidentificationWidget = reIdentificationWidget(
            matchList=self.wavelengthPixelList,
            flux=self.selfData,
            FWHM=self.selfFWHM)

        self.initUI()

    def initUI(self):
        """
        Create all widgets, connect their signals and assemble the window.

        Each slider gets its range before its value, so the value is
        clamped to the intended range and every slider gets its own range.
        The valueChanged handlers are connected only after the initial
        values are set, so initialisation does not trigger a redraw.
        """
        self.layout = QGridLayout()
        self.mainWidget = QWidget()

        self.calibrationFileOpenAction = QAction('CalibrationFileOpen', self)
        self.calibrationFileOpenAction.setShortcut('Ctrl+C')
        self.calibrationFileOpenAction.triggered.connect(
            self.onCalibnationFileOpen)
        self.calibrationFileOpenAction.setStatusTip('Open calibration image')

        # Figure showing the spectrum of the user's own identification image
        self.selfSpectrumCanvas = FigureCanvas(Figure(figsize=(13, 5)))
        self.selfSpectrumCanvas.figure.clear()

        # --- Peak-finding controls ---
        self.peakNumberSlider = QSlider(Qt.Horizontal, self)
        self.peakNumberSlider.setRange(1, 100)
        self.peakNumberSlider.setValue(self.peakNumber)
        self.peakDistanceSlider = QSlider(Qt.Horizontal, self)
        self.peakDistanceSlider.setRange(1, 10)
        self.peakDistanceSlider.setValue(self.peakDistance)
        # The threshold slider works in hundredths of the threshold value
        self.peakThresholdSlider = QSlider(Qt.Horizontal, self)
        self.peakThresholdSlider.setRange(1, 100)
        self.peakThresholdSlider.setValue(int(self.peakThreshold * 100))

        self.peakNumberLabel = QLabel(f'Number of Peak = {self.peakNumber}')
        self.peakDistanceLabel = QLabel(
            f'Distance between Peak = {self.peakDistance}')
        self.peakThresholdLabel = QLabel(
            f'Threshold of peak = {self.peakThreshold}')

        self.peakNumberSlider.valueChanged.connect(
            self.onPeakNumberValueChanged)
        self.peakDistanceSlider.valueChanged.connect(
            self.onPeakDistanceValueChanged)
        self.peakThresholdSlider.valueChanged.connect(
            self.onPeakThresholdValueChanged)

        self.selfPeakControl = QWidget()
        self.peakControlLayout = QVBoxLayout()
        self.peakControlLayout.addWidget(self.peakNumberLabel)
        self.peakControlLayout.addWidget(self.peakNumberSlider)
        self.peakControlLayout.addWidget(self.peakDistanceLabel)
        self.peakControlLayout.addWidget(self.peakDistanceSlider)
        self.peakControlLayout.addWidget(self.peakThresholdLabel)
        self.peakControlLayout.addWidget(self.peakThresholdSlider)

        self.selfPeakControl.setLayout(self.peakControlLayout)

        # Figure showing the user's own identification image
        self.selfImageCanvas = FigureCanvas(Figure(figsize=(5, 2)))

        self.selfImageCanvas.mpl_connect("button_press_event",
                                         self.onPressAtImage)
        self.selfImageCanvas.mpl_connect("motion_notify_event",
                                         self.onMoveAtImage)
        self.selfImageCanvas.mpl_connect("button_release_event",
                                         self.onReleaseAtImage)

        self.selfSpectrumCanvas.mpl_connect('scroll_event',
                                            self.onScrollAtSelfSpectrum)
        self.selfSpectrumCanvas.mpl_connect('pick_event',
                                            self.onPickPeakAtSelfSpectrum)
        self.selfSpectrumCanvas.mpl_connect("button_press_event",
                                            self.onPressAtSelfSpectrum)
        self.selfSpectrumCanvas.mpl_connect("motion_notify_event",
                                            self.onMoveAtSelfSpectrum)
        self.selfSpectrumCanvas.mpl_connect("button_release_event",
                                            self.onReleaseAtSelfSpectrum)

        # --- Gauss-fit panel ---
        self.selfSpectrumGaussFitCanvas = FigureCanvas(Figure(figsize=(7, 7)))
        self.gaussFitWidget = QWidget()
        self.gaussFitLayout = QVBoxLayout()
        self.gaussFitButton = QPushButton('&Yes')
        self.gaussFitButton.clicked.connect(self.onGaussFitButtonClicked)
        # The FWHM slider works in tenths of the FWHM value
        self.FWHMSlider = QSlider(Qt.Horizontal, self)
        self.FWHMSlider.setRange(1, 100)
        self.FWHMSlider.setValue(self.selfFWHM * 10)

        self.FWHMLabel = QLabel(f'FHWM for comp image = {self.selfFWHM}')
        self.FWHMSlider.valueChanged.connect(self.onFWHMChanged)
        self.gaussFitLayout.addWidget(self.selfSpectrumGaussFitCanvas)
        self.gaussFitLayout.addWidget(self.FWHMSlider)
        self.gaussFitLayout.addWidget(self.FWHMLabel)
        self.gaussFitLayout.addWidget(self.gaussFitButton)
        self.gaussFitWidget.setLayout(self.gaussFitLayout)

        # --- Buttons selecting the standard (reference) spectrum ---
        self.NeonArcButton = QPushButton('&Neon')
        self.NeonArcButton.clicked.connect(self.neonSpectrumDraw)
        self.OpenArcButton = QPushButton('&Open')

        self.standardSpectrumButtonLayout = QVBoxLayout()
        self.standardSpectrumButtonLayout.addWidget(self.NeonArcButton)
        self.standardSpectrumButtonLayout.addWidget(self.OpenArcButton)

        self.standardSpectrumButton = QWidget()
        self.standardSpectrumButton.setLayout(
            self.standardSpectrumButtonLayout)

        # Figure showing the reference spectrum to compare against
        self.standardSpectrumCanvas = FigureCanvas(Figure(figsize=(13, 5)))

        self.standardSpectrumCanvas.mpl_connect(
            'scroll_event', self.onScrollAtStandardSpectrum)
        self.standardSpectrumCanvas.mpl_connect(
            'pick_event', self.onPickPeakAtStandardSpectrum)
        self.standardSpectrumCanvas.mpl_connect('button_press_event',
                                                self.onPressAtStandardSpectrum)

        # --- Wavelength / pixel match table ---
        self.wavelengthPixelTable = QTableView()
        self.wavelengthPixelModel = tableModel(self.wavelengthPixelList)
        self.wavelengthPixelTable.setModel(self.wavelengthPixelModel)
        self.wavelengthPixelTable.setSelectionBehavior(QTableView.SelectRows)
        self.wavelengthPixelTable.doubleClicked.connect(
            self.onWavelengthPixelTableDoubleClicked)

        self.gaussButton = QPushButton('&GuassFit')
        self.gaussButton.clicked.connect(self.selfSpectrumDrawWithGauss)

        self.matchButton = QPushButton('&Match')
        self.matchButton.clicked.connect(self.onMatch)
        self.abortButton = QPushButton('&Abort')
        self.abortButton.clicked.connect(self.onAbort)
        self.exportButton = QPushButton('&Export')
        self.exportButton.clicked.connect(self.onExport)
        self.importButton = QPushButton('&Import')
        self.importButton.clicked.connect(self.onImport)

        self.tableMatchingButtons = QWidget()
        self.tableMatchingButtonLayout = QVBoxLayout()
        self.tableMatchingButtonLayout.addWidget(self.matchButton)
        self.tableMatchingButtonLayout.addWidget(self.abortButton)
        self.tableMatchingButtonLayout.addWidget(self.exportButton)
        self.tableMatchingButtonLayout.addWidget(self.importButton)
        self.tableMatchingButtons.setLayout(self.tableMatchingButtonLayout)

        # --- Menu bar ---
        menubar = self.menuBar()
        menubar.setNativeMenuBar(False)

        # The "&" makes the File menu reachable with Alt+F
        filemenu = menubar.addMenu('&File')
        filemenu.addAction(self.calibrationFileOpenAction)

        # --- Left: table column, right: the two spectra ---
        self.splitter = QSplitter(Qt.Horizontal)
        self.tables = QWidget()
        self.tableLayout = QVBoxLayout()
        self.tableLayout.addWidget(self.gaussButton)
        self.tableLayout.addWidget(self.wavelengthPixelTable)
        self.tableLayout.addWidget(self.tableMatchingButtons)
        self.tables.setLayout(self.tableLayout)

        self.splitter.addWidget(self.tables)
        self.spectrums = QWidget()
        self.spectrumsLayout = QVBoxLayout()
        self.spectrumsLayout.addWidget(self.selfSpectrumCanvas)
        self.spectrumsLayout.addWidget(self.standardSpectrumCanvas)
        self.spectrums.setLayout(self.spectrumsLayout)

        self.splitter.addWidget(self.spectrums)

        self.reidentificationBtn = QPushButton('&Reidentification')
        self.reidentificationBtn.clicked.connect(self.onReidentification)

        # --- Overall grid ---
        self.layout.addWidget(self.splitter, 1, 0, 3, 1)
        self.layout.addWidget(self.selfImageCanvas, 1, 1, 1, 1)
        self.layout.addWidget(self.selfPeakControl, 2, 1, 1, 1)
        self.layout.addWidget(self.standardSpectrumButton, 3, 1)
        self.layout.addWidget(self.reidentificationBtn, 4, 0, 1, -1)

        self.mainWidget.setLayout(self.layout)
        self.setCentralWidget(self.mainWidget)
        self.resize(1500, 800)
        self.center()

    '''
    칼리브레이션 파일(comp 파일)을 열고 Identification에 사용될 Y방향(Wavelength에 수직한 방향) 구간을 결정하는 메소드    
    '''

    def onCalibnationFileOpen(self):
        """
        Let the user pick a calibration (comp) file, show its image on the
        image canvas and remember the image data and width.

        Nothing happens when the dialog is cancelled.
        """
        filePath = QFileDialog.getOpenFileName(
            self, 'Open calibration file',
            './Spectroscopy_Example/20181023/combine/')[0]
        if not filePath:
            # A cancelled dialog yields an empty path; there is nothing to
            # open then.
            return

        hdr, data = openFitData(filePath)
        self.selfImageCanvas.figure.clear()
        self.selfImageAx = self.selfImageCanvas.figure.add_subplot(111)
        zimshow(self.selfImageAx, data)
        self.selfImageCanvas.draw()

        self.imageWidth = int(data.shape[1])
        self.selfData = data

    '''
    칼리브레이션파일 스펙트럼에서 Peak을 찾는 과정에 관여하는 3가지 initial value(number of peaks, distance btw peaks,
    threshol of peaks)를 조정해서 적절히 Peak을 찾을 수 있게 하는 메소드 
    Slider의 값을 받아서 canvas에 적용한다.
    '''

    def onPeakNumberValueChanged(self, val):
        """
        Slider handler: store the new number of peaks, update its label and
        redraw the calibration spectrum with the current peak parameters.
        """
        self.peakNumber = val
        self.peakNumberLabel.setText(f'Number of Peak = {self.peakNumber}')

        peak_args = [self.peakDistance, self.peakThreshold, self.peakNumber]
        self.selfSpectrumDraw(ymin=self.selfImageY[0],
                              ymax=self.selfImageY[1],
                              data=self.selfData,
                              args=peak_args)

    def onPeakDistanceValueChanged(self, val):
        """
        Slider handler: store the new distance between peaks, update its
        label and redraw the calibration spectrum with the current peak
        parameters.
        """
        self.peakDistance = val
        label_text = f'Distance between Peak = {self.peakDistance}'
        self.peakDistanceLabel.setText(label_text)

        peak_args = [self.peakDistance, self.peakThreshold, self.peakNumber]
        self.selfSpectrumDraw(ymin=self.selfImageY[0],
                              ymax=self.selfImageY[1],
                              data=self.selfData,
                              args=peak_args)

    def onPeakThresholdValueChanged(self, val):
        """
        Slider handler: the slider works in hundredths, so the threshold is
        val / 100. Stores it, updates its label and redraws the calibration
        spectrum with the current peak parameters.
        """
        self.peakThreshold = val / 100
        label_text = f'Threshold of peak = {self.peakThreshold}'
        self.peakThresholdLabel.setText(label_text)

        peak_args = [self.peakDistance, self.peakThreshold, self.peakNumber]
        self.selfSpectrumDraw(ymin=self.selfImageY[0],
                              ymax=self.selfImageY[1],
                              data=self.selfData,
                              args=peak_args)

    '''
    칼리브레이션 스펙트럼 그래프를  확대/축소하는 메소드. 
    스크롤을 내리면 마우스 위치를 중심으로 xlim이 4/5배가 되고,
    스크롤을 올리면 마우스 위치를 중심으로 xlim이 5/4배가 된다.
    '''

    def onScrollAtSelfSpectrum(self, event):
        """
        Zoom the calibration spectrum horizontally around the mouse position.

        Scrolling up shrinks the visible x range to about 4/5 of its width,
        scrolling down widens it to about 5/4; the new range is centred on
        the x position of the mouse. Events without an x position (pointer
        outside the axes) and other button values are ignored.
        """
        xnow = event.xdata
        if xnow is None:
            # The pointer is not over an axes, so there is no zoom centre
            return

        # Half-width of the new range as a fraction of the current width
        if event.button == 'up':
            half_fraction = 0.40
        elif event.button == 'down':
            half_fraction = 0.625
        else:
            return

        xmin, xmax = self.selfSpectrumAx.get_xlim()
        xsize = int((xmax - xmin) * half_fraction)
        self.selfSpectrumAx.set_xlim(xnow - xsize, xnow + xsize)
        self.selfSpectrumAx.figure.canvas.draw()

    '''
    칼리브레이션 스펙트럼 그래프에서 pickPeak 메소드를 통해 생성된 peakPicker를 움직이고 그 값을 table에 저장하는 메소드
    두 가지 방식으로 peakPicker를 움직일 수 있다. 
    1. Drag and Drop :
        peakPicker를 클릭한 채로 끌어서 움직일 수 있고 마우스를 놓으면 위치가 고정된다. 
        peak 근처 peakDistance 픽셀에서는 자동으로 peak에 붙고 이때 색깔이 연두색으로 바뀐다. 
        마우스가 이동할때 그 픽셀값이 저장된다. 
    2. Double Click :
        selfSpectrum의 Peak의 text를 더블클릭하면 peakPicker 그 text로 이동하고 그 픽셀값이 저장된다. 
    '''

    def onPickPeakAtSelfSpectrum(self, event):
        """
        Pick handler of the calibration spectrum.

        Ignored while no match is in progress or while the picker is already
        being dragged. A double click on an artist other than the picker
        reads that artist's text as a pixel value, stores it for the
        currently picked wavelength and re-creates the picker (green) at
        that position. Any other pick with the left button starts a drag.
        """
        if self.isMatchFinished or self.isPicked:
            return

        mouse = event.mouseevent
        if mouse.dblclick and event.artist != self.peakPicker:
            val = round(float(event.artist.get_text()), 4)
            picked_rows = (self.wavelengthPixelList.Wavelength ==
                           self.currentPickedPeakWavelength)
            self.wavelengthPixelList.loc[picked_rows, 'Pixel'] = val

            # Replace the picker line by one at the chosen peak
            self.peakPicker.remove()
            self.peakPicker = self.selfSpectrumAx.axvline(
                val, color='green', picker=True, pickradius=self.pickDistance)
            self.selfSpectrumAx.figure.canvas.draw()
            self.onChangedList()
            return

        if mouse.button == 1:
            self.isPicked = True

    def onMoveAtSelfSpectrum(self, event):
        """
        Drag handler of the calibration spectrum.

        While the picker is being dragged inside the spectrum axes it is
        re-created at the mouse position. Closer than peakDistance to a
        known peak it snaps onto that peak and turns green; otherwise it
        follows the mouse in blue. The resulting pixel value is stored for
        the currently picked wavelength.
        """
        if not event.inaxes:
            return
        if event.inaxes != self.selfSpectrumAx:
            return
        if not self.isPicked:
            return

        self.peakPicker.remove()

        # Distance from the mouse to every known peak, and the nearest one
        offsets = np.abs(self.selfPeakPixs - event.xdata)
        nearest = np.argmin(offsets)

        if offsets[nearest] < self.peakDistance:
            line_x = stored = self.selfPeakPixs[nearest]
            line_color = 'green'
        else:
            line_x = event.xdata
            stored = int(event.xdata)
            line_color = 'blue'

        self.peakPicker = self.selfSpectrumAx.axvline(
            line_x,
            color=line_color,
            picker=True,
            pickradius=self.pickDistance)

        picked_rows = (self.wavelengthPixelList.Wavelength ==
                       self.currentPickedPeakWavelength)
        self.wavelengthPixelList.loc[picked_rows, 'Pixel'] = stored

        self.selfSpectrumAx.figure.canvas.draw()
        self.onChangedList()

    def onReleaseAtSelfSpectrum(self, event):
        """End a picker drag when the button is released inside the
        calibration spectrum axes; releases elsewhere are ignored."""
        outside = (not event.inaxes) or event.inaxes != self.selfSpectrumAx
        if outside or not self.isPicked:
            return
        self.isPicked = False

    '''
    칼리브레이션 스펙트럼 그래프에서 우클릭을 하면 pickPeak과정을 취소하는 메소드
    우클릭을 하면 self.onPickDisable을 호출한다.     
    '''

    def onPressAtSelfSpectrum(self, event):
        """Cancel the current peak pick when mouse button 3 is pressed on
        the calibration spectrum."""
        if event.button != 3:
            return
        self.onPickDisable()

    '''
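    Zoom the standard spectrum graph in or out.
    Scrolling up shrinks the x-range to 4/5 of its width, centred on the mouse position;
    scrolling down widens the x-range to 5/4 of its width, centred on the mouse position.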
     '''

    def onScrollAtStandardSpectrum(self, event):
        """Zoom the standard spectrum around the mouse position.

        Scrolling up sets the new half-width to 0.40 of the current x-range
        (total width 4/5, zoom in); scrolling down sets it to 0.625 (total
        width 5/4, zoom out).  The half-width is truncated to an integer, as
        before.  Events without a data position (``event.xdata`` is None) and
        unknown scroll directions are ignored instead of raising.
        """
        centre = event.xdata
        if centre is None:
            # No data coordinate for this event; nothing to centre on.
            return

        # Fraction of the current range used as the new half-width.
        factor = {'up': 0.40, 'down': 0.625}.get(event.button)
        if factor is None:
            return

        left, right = self.standardSpectrumAx.get_xlim()
        half_width = int((right - left) * factor)
        self.standardSpectrumAx.set_xlim(centre - half_width,
                                         centre + half_width)
        self.standardSpectrumAx.figure.canvas.draw()

    '''
    pickPeak을 시작하기 위한 조건을 나타낸 메소드들. 
    standard spectrum 그래프에서 peak wavelength text를 더블클릭하거나 왼쪽 테이블에서 wavelength를 더블클릭하면 그
    wavelength에 맞는 pickPeak이 실행된다. 
    '''

    def onPickPeakAtStandardSpectrum(self, event):
        """Start a peak pick when a picked artist on the standard spectrum
        is double-clicked; the artist's x position is used as the wavelength."""
        mouse = event.mouseevent
        if not mouse.dblclick:
            return
        wavelength = event.artist.get_position()[0]
        self.pickPeak(wavelength, self.standardSpectrum, mouse)

    def onWavelengthPixelTableDoubleClicked(self, index):
        """Start a peak pick for the wavelength in the double-clicked table
        row (``index.row()`` selects from ``standardPeakWavelengths``)."""
        self.pickPeak(self.standardPeakWavelengths[index.row()],
                      self.standardSpectrum)

    '''
    peak wavelength에 맞는 peak Pixel을 찾기 위한 메소드
    선택된 wavelength와 그 peak의 axvline를 파란색으로 강조해서 보여주고 칼리브레이션 스펙트럼 그래프의 중간이나 그래프상의 
    같은 위치에 움직일수 있는 peakPicker를 생성해 해당 wavelength에 맞는 peak Pixel을 찾을 수 있도록 한다. 
    '''

    def pickPeak(self, waveNow, spectrum, mouse=None):
        """Begin matching the standard peak at ``waveNow`` to a pixel.

        Highlights the chosen peak on the standard spectrum with a blue line
        and label, then creates a movable ``peakPicker`` line on the
        calibration spectrum: at the centre of its current x-range when
        ``mouse`` is None, otherwise at the data x that corresponds to the
        mouse event's display x.  Does nothing while a previous match is
        still in progress.

        ``spectrum[0]`` holds wavelengths and ``spectrum[1]`` the fluxes;
        ``waveNow`` must match one wavelength exactly.
        """
        if not self.isMatchFinished:
            return

        wavelengths, fluxes = spectrum[0], spectrum[1]
        peak_flux = fluxes[np.where(wavelengths == waveNow)][0]
        self.pickedPeak = self.standardSpectrumAx.axvline(waveNow,
                                                          color='blue')

        flux_max = max(fluxes)
        label_box = dict(facecolor='white', ec='none')
        if peak_flux + flux_max / 2.85 > flux_max:
            # Tall peak: put a horizontal label just above the peak top.
            self.pickedText = self.standardSpectrumAx.text(
                waveNow,
                peak_flux + flux_max / 2000,
                waveNow,
                c='blue',
                bbox=label_box)
        else:
            # Otherwise use a rotated label offset well above the peak.
            self.pickedText = self.standardSpectrumAx.text(
                waveNow,
                peak_flux + flux_max / 2.85,
                waveNow,
                ha='center',
                va='center',
                rotation=90,
                clip_on=True,
                c='blue',
                bbox=label_box)

        if mouse is None:
            lower, upper = self.selfSpectrumAx.get_xlim()
            picker_x = (upper - lower) / 2 + lower
        else:
            picker_x = self.selfSpectrumAx.transData.inverted().transform(
                (mouse.x, 0))[0]

        self.peakPicker = self.selfSpectrumAx.axvline(
            picker_x, color='blue', picker=True, pickradius=self.pickDistance)
        self.selfSpectrumAx.figure.canvas.draw()
        self.standardSpectrumAx.figure.canvas.draw()
        self.currentPickedPeakWavelength = waveNow
        self.isMatchFinished = False

    '''
    pickPeak을 완료하기 위한 메소드 
    매치가 완료되었으면(onMatch) 해당 내용을 리스트에 저장하고 강제종료시(onAbort) 저장하지 않는다. 
    테이블 아래있는 버튼 (Match, Abort)을 클릭하거나 selfSpectrumFigure 위에서 키(m on Match, a on Abort)를 누르면 완료된다. 

    '''

    def keyPressEvent(self, event):
        """Finish the current peak pick from the keyboard.

        While a match is in progress, ``M`` confirms it (``onMatch``) and
        ``A`` aborts it (``onAbort``).  Keys are ignored when no match is in
        progress.
        """
        if self.isMatchFinished:
            return
        pressed = event.key()
        if pressed == Qt.Key_M:
            self.onMatch()
        elif pressed == Qt.Key_A:
            # Previously called the non-existent ``onAatch``, which raised
            # AttributeError whenever 'A' was pressed.
            self.onAbort()

    def onMatch(self):
        """Confirm the current match: remove the pick artists and refresh
        the wavelength/pixel table, keeping the stored pixel value."""
        print('match')
        for step in (self.onPickDisable, self.onChangedList):
            step()

    def onAbort(self):
        """Abort the current match: reset the 'Pixel' entry of the wavelength
        being matched to 0, remove the pick artists and refresh the table."""
        rows = (self.wavelengthPixelList.Wavelength ==
                self.currentPickedPeakWavelength)
        self.wavelengthPixelList.loc[rows, 'Pixel'] = 0
        self.onPickDisable()
        self.onChangedList()

    def onExport(self):
        """Ask for a destination and write the wavelength/pixel list as CSV.

        The list is written without the index column.  Nothing is written
        when the user cancels the dialog (the returned path is empty).
        """
        path = QFileDialog.getSaveFileName(
            self, 'Choose save file location and name ', './',
            "CSV files (*.csv)")[0]
        if not path:
            # Dialog cancelled: an empty path would make to_csv fail.
            return
        self.wavelengthPixelList.to_csv(path, index=False)

    def onImport(self):
        """Ask for a CSV match file and load it as the wavelength/pixel list.

        The table view is refreshed after loading.  The current list is kept
        unchanged when the user cancels the dialog (the returned path is
        empty).
        """
        file = QFileDialog.getOpenFileName(self, 'Choose match file', './',
                                           "CSV files (*.csv)")[0]
        if not file:
            # Dialog cancelled: read_csv('') would raise.
            return
        self.wavelengthPixelList = pd.read_csv(file)
        self.onChangedList()

    def onPressAtStandardSpectrum(self, event):
        """Cancel the current peak pick when mouse button 3 is pressed on
        the standard spectrum."""
        if event.button != 3:
            return
        self.onPickDisable()

    def onPickDisable(self):
        """Remove the artists of the current peak pick and mark it finished.

        Removes the highlighted standard peak line, its label and the movable
        peak picker, redraws both spectra and sets ``isMatchFinished``.

        This is reachable from right-clicks and the Match/Abort actions even
        when no pick is active; in that case the artists were already removed
        (or never created), so the call is a no-op instead of raising.
        """
        if self.isMatchFinished:
            return

        self.pickedPeak.remove()
        self.pickedText.remove()
        self.peakPicker.remove()

        self.selfSpectrumAx.figure.canvas.draw()
        self.standardSpectrumAx.figure.canvas.draw()
        self.isMatchFinished = True

    '''
    Select the y-axis (row) range to use from the comp image opened by onCalibrationOpen.
    Click and drag to choose the range; on release, selfSpectrumDraw is called for that range.
    '''

    def onPressAtImage(self, event):
        """Start a row-range selection on the comp image.

        Adds a semi-transparent placeholder rectangle to the image axes and
        records the press position in ``x0``/``y0``.
        """
        if not event.inaxes or event.inaxes != self.selfImageAx:
            return
        band = Rectangle((0, 0), 1, 1, alpha=0.5)
        self.rect = band
        self.selfImageAx.add_patch(band)
        self.x0, self.y0 = event.xdata, event.ydata
        self.isPressed = True

    def onMoveAtImage(self, event):
        """Stretch the selection rectangle while dragging on the comp image.

        The rectangle always spans ``imageWidth`` horizontally from x=0 and
        runs vertically from the press row ``y0`` to the current mouse row.
        """
        if not event.inaxes or event.inaxes != self.selfImageAx:
            return
        if not self.isPressed:
            return
        self.x1, self.y1 = event.xdata, event.ydata
        band = self.rect
        band.set_width(self.imageWidth)
        band.set_height(self.y1 - self.y0)
        band.set_xy((0, self.y0))
        self.selfImageAx.figure.canvas.draw()

    def onReleaseAtImage(self, event):
        """Finish the row-range selection and draw the calibration spectrum.

        The selection rectangle is removed and its vertical extent, truncated
        to integers, becomes the row range ``[ymin, ymax]`` stored in
        ``selfImageY`` and passed to ``selfSpectrumDraw``.

        The rectangle is anchored at the press row, so an upward drag yields
        a negative height; the range is then ``[y + height, y]``.  (The
        previous code used ``[y, y + |height|]``, i.e. rows on the wrong side
        of the press position.)  An empty range is ignored, and the lower
        bound is clamped to 0 so it cannot become a negative slice index.
        """
        if not event.inaxes or event.inaxes != self.selfImageAx:
            return
        if not self.isPressed:
            return

        start = int(self.rect.get_y())
        stop = start + int(self.rect.get_height())
        self.rect.remove()
        self.selfImageAx.figure.canvas.draw()
        # The drag is over regardless of what the spectrum drawing does.
        self.isPressed = False

        ymin = max(min(start, stop), 0)
        ymax = max(start, stop)
        if ymax <= ymin:
            # Nothing selected; averaging an empty slice would give NaNs.
            return

        self.selfImageY = np.array([ymin, ymax])
        self.selfSpectrumDraw(
            ymin=ymin,
            ymax=ymax,
            data=self.selfData,
            args=[self.peakDistance, self.peakThreshold, self.peakNumber])

    def selfSpectrumDrawWithGauss(self):
        """Open the Gaussian-fit preview for the current calibration spectrum
        and its detected peaks."""
        spectrum, peaks = self.selfSpectrum, self.selfPeakPixs
        self.selfSpectrumGaussShow(spectrum, peaks)

    def selfSpectrumDraw(self, ymin, ymax, data, args):
        """Extract, peak-detect and plot the calibration spectrum.

        The rows ``ymin:ymax`` of ``data`` are averaged column-wise into a 1-D
        spectrum.  Peaks are located with ``peak_local_max`` and each one is
        marked with a cyan line and a pickable text label showing its pixel
        index.  The spectrum and peak positions are stored in
        ``selfSpectrum`` and ``selfPeakPixs``.

        ``args`` is ``[min_separation, min_amplitude_fraction, max_peaks]``:
        the minimum distance between peaks, the fraction of the maximum
        intensity (added to the background level) used as absolute threshold,
        and the maximum number of peaks.
        """
        min_separation = args[0]
        min_amplitude_fraction = args[1]
        max_peaks = args[2]

        figure = self.selfSpectrumCanvas.figure
        figure.clear()
        axes = figure.add_subplot(111)
        self.selfSpectrumAx = axes

        identify = np.average(data[ymin:ymax, :], axis=0)
        # Background level estimated from the first 200 columns.
        ground = np.median(identify[0:200])
        max_intens = np.max(identify)
        found = peak_local_max(identify,
                               indices=True,
                               num_peaks=max_peaks,
                               min_distance=min_separation,
                               threshold_abs=max_intens *
                               min_amplitude_fraction + ground)
        # peak_local_max yields one coordinate row per peak; keep the index.
        peakPixs = [coords[0] for coords in found]

        self.selfPeakPixs = np.array(peakPixs)
        self.selfSpectrum = np.array(identify)

        top = max(identify)
        for pix in peakPixs:
            level = identify[pix]
            # Marker line is given in axes fractions, just above the peak.
            axes.axvline(pix,
                         level / top + 0.0003,
                         level / top + 0.2,
                         color='c')
            if level + top / 2.85 > top:
                axes.text(pix,
                          level + top / 2000,
                          str(pix),
                          clip_on=False,
                          picker=self.pickDistance)
            else:
                axes.text(pix,
                          level + top / 2.85,
                          str(pix),
                          ha='center',
                          va='center',
                          rotation=90,
                          clip_on=True,
                          picker=self.pickDistance)

        axes.plot(identify, color='r')
        axes.set_xlim(0, len(identify))
        axes.set_ylim(0, )
        axes.figure.canvas.draw()

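    # Find accurate peak pixel positions with Gaussian fits.
    # Show the spectra of the three highest peaks and their Gaussian fits as examples, so that when the
    # FWHM is unknown or has to be guessed it can be adjusted while checking whether the fits look right.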

    def selfSpectrumGaussShow(self, identify, peakPixs):
        """Show and raise the Gaussian-fit widget, then draw the preview
        fits for ``identify`` at ``peakPixs``."""
        widget = self.gaussFitWidget
        widget.show()
        widget.raise_()
        self.selfSpectrumGaussDraw(identify, peakPixs)


#Todo 이거 잘 빼는 방법(바닥값에 맞게 잘 빼는 방법)을 찾아보자.

    def selfSpectrumGaussDraw(self, identify, peakPixs):
        """Preview Gaussian fits for up to three detected peaks.

        The median of the first 100 samples is subtracted as background, then
        for each previewed peak a bounded ``Gaussian1D`` is fitted to a window
        of +/- 5 * int(FWHM) pixels and plotted (data in blue, fit as a red
        dashed line) in its own subplot of the fit canvas.

        Fixes over the previous version:
          * peak positions are rounded to integers before being used as
            array indices (after a Gauss fit ``selfPeakPixs`` holds floats,
            which made ``identify[peakPix]`` raise);
          * the number of previews is limited to the number of peaks, so
            fewer than three peaks no longer raises IndexError;
          * the fit window is clipped to the spectrum, so it can neither wrap
            around through negative indices nor run past the end;
          * the subplot row count is computed once and the canvas is redrawn
            once after the loop.
        """
        iterations = min(3, len(peakPixs))
        columns = 5
        rows = -(-iterations // columns)  # ceiling division

        fitter = LevMarLSQFitter()
        figure = self.selfSpectrumGaussFitCanvas.figure
        figure.clear()
        # TODO (from the original author): find a better way to subtract the
        # background level.
        identify = identify - np.median(identify[0:100])
        half_window = int(self.selfFWHM) * 5

        for i in range(iterations):
            ax = figure.add_subplot(rows, columns, i + 1)
            # NOTE(review): index -i picks peakPixs[0], [-1], [-2]; if the
            # three strongest peaks are intended this may need to be [i] -
            # confirm against the ordering of the peak list.
            peakPix = int(round(peakPixs[-i]))
            window_start = max(peakPix - half_window, 0)
            window_stop = min(peakPix + half_window + 1, len(identify))
            xs = np.arange(window_start, window_stop)

            g_init = Gaussian1D(amplitude=identify[peakPix],
                                mean=peakPix,
                                stddev=self.selfFWHM * gaussian_fwhm_to_sigma,
                                bounds={
                                    'amplitude': (0, 2 * identify[peakPix]),
                                    'mean': (peakPix - self.selfFWHM,
                                             peakPix + self.selfFWHM),
                                    'stddev': (0, self.selfFWHM)
                                })

            ax.set_ylim(0, max(identify) * 1.1)
            fitted = fitter(g_init, xs, identify[xs])
            # Zoom to +/- two fitted FWHM around the peak.
            ax.set_xlim(peakPix - fitted.stddev / gaussian_fwhm_to_sigma * 2,
                        peakPix + fitted.stddev / gaussian_fwhm_to_sigma * 2)
            ax.plot(xs, identify[xs], 'b')
            xss = np.arange(peakPix - self.selfFWHM * 5,
                            peakPix + self.selfFWHM * 5 + 1, 0.01)
            ax.plot(xss, fitted(xss), 'r--')

        figure.canvas.draw()

    def selfSpectrumGaussFit(self, identify, peakPixs):
        """Refine every detected peak with a Gaussian fit and redraw the
        calibration spectrum.

        The median of the first 100 samples is subtracted as background.  The
        peaks are then fitted one after another, in the order given by
        ``peakPixs``: a bounded ``Gaussian1D`` is fitted to the whole residual
        spectrum, its curve is plotted, its mean is recorded, and the fitted
        model is subtracted from the residual before the next peak is fitted.

        Afterwards the fitted means (rounded to 4 decimals) are drawn as peak
        markers with pickable labels, stored in ``selfPeakPixs``, and the
        Gaussian-fit widget is closed.

        NOTE(review): ``identify[i]`` and ``identify_fit[peakPix]`` index with
        the values of ``peakPixs``, so this looks like it requires integer
        peak positions - confirm what happens when it is called again after
        ``selfPeakPixs`` has been replaced by the fitted (float) means.
        """

        self.selfSpectrumCanvas.figure.clear()
        self.selfSpectrumAx = self.selfSpectrumCanvas.figure.add_subplot(111)

        fitter = LevMarLSQFitter()

        # NOTE(review): sortedPeakPixs and i are only used by the disabled
        # code kept in the string literal further down.
        sortedPeakPixs = np.sort(peakPixs)
        ground = np.median(identify[0:100])
        identify_fit = identify - ground  ## here as well! (background subtraction)

        peak_gauss = []
        i = 0
        x_identify = np.arange(len(identify_fit))
        for peakPix in peakPixs:
            # Amplitude is bounded between 1x and 2x the residual value at the
            # peak, the mean within +/- FWHM of it, the sigma by the FWHM.
            g_init = Gaussian1D(amplitude=identify_fit[peakPix],
                                mean=peakPix,
                                stddev=self.selfFWHM * gaussian_fwhm_to_sigma,
                                bounds={
                                    'amplitude': (identify_fit[peakPix],
                                                  2 * identify_fit[peakPix]),
                                    'mean': (peakPix - self.selfFWHM,
                                             peakPix + self.selfFWHM),
                                    'stddev': (0, self.selfFWHM)
                                })
            fitted = fitter(g_init, x_identify, identify_fit)
            # Plot the fitted curve (background added back) around the peak.
            xss = np.arange(peakPix - int(self.selfFWHM) * 3,
                            peakPix + int(self.selfFWHM) * 3 + 1, 0.01)
            self.selfSpectrumAx.plot(xss, fitted(xss) + ground, 'royalblue')

            peak_gauss.append(fitted.mean.value)
            # Remove this peak's model so later fits see only the residual.
            identify_fit = identify_fit - fitted(np.arange(len(identify)))
        '''
        while i < len(sortedPeakPixs)-1:
            peakPix = sortedPeakPixs[i]
            try:
                peakPix2 = sortedPeakPixs[i + 1]
            except:
                peakPix2 = int(peakPix + self.selfFWHM * 10)

            xs = np.arange(peakPix - int(self.selfFWHM) * 3, peakPix2 + int(self.selfFWHM) * 3 + 1)
            g_init = Gaussian1D(amplitude=identify_fit[peakPix],
                                mean=peakPix,
                                stddev=self.selfFWHM * gaussian_fwhm_to_sigma,
                                bounds={'amplitude': (0, 2 * identify_fit[peakPix]),
                                        'mean': (peakPix - self.selfFWHM, peakPix + self.selfFWHM),
                                        'stddev': (0, self.selfFWHM)}
                                )+\
                     Gaussian1D(amplitude=identify_fit[peakPix2],
                                mean=peakPix2,
                                stddev=self.selfFWHM * gaussian_fwhm_to_sigma,
                                bounds={'amplitude': (0, 2 * identify_fit[peakPix2]),
                                        'mean': (peakPix2 - self.selfFWHM, peakPix2 + self.selfFWHM),
                                        'stddev': (0, self.selfFWHM)}
                                )
            fitted = fitter(g_init, xs, identify_fit[xs])



            # fit 한 두 값이 mean 차이가  시그마의 합의 3배 보다 크면 1D fit
            if (fitted.mean_1.value - fitted.mean_0.value > 3* ( fitted.stddev_1 + fitted.stddev_0) ) :
                xs = np.arange(peakPix - int(self.selfFWHM) * 3, peakPix + int(self.selfFWHM) * 3 + 1)
                g_init = Gaussian1D(amplitude=identify_fit[peakPix],
                                    mean=peakPix,
                                    stddev=self.selfFWHM * gaussian_fwhm_to_sigma,
                                    bounds={'amplitude': (0, 2 * identify_fit[peakPix]),
                                            'mean': (peakPix - self.selfFWHM, peakPix + self.selfFWHM),
                                            'stddev': (0, self.selfFWHM)}
                                    )
                fitted = fitter(g_init, xs, identify_fit[xs])
                xss = np.arange(peakPix - int(self.selfFWHM) * 3, peakPix + int(self.selfFWHM) * 3 + 1, 0.01)
                self.selfSpectrumAx.plot(np.arange(len(identify)), fitted(np.arange(len(identify))) + ground, 'royalblue')
                peak_gauss.append(fitted.mean.value)
                i+=1
            else:
                peak_gauss.append(fitted.mean_0.value)
                peak_gauss.append(fitted.mean_1.value)
                xss = np.arange(peakPix - int(self.selfFWHM) * 3, peakPix2 + int(self.selfFWHM) * 3 + 1, 0.01)
                self.selfSpectrumAx.plot(xss, fitted(xss) + ground, 'yellowgreen')
                i += 2

        print(len(peak_gauss))
        print(len(sortedPeakPixs))
        '''
        peak_gauss = np.round(peak_gauss, 4)

        # Mark each peak at its fitted position j; heights still come from
        # the original spectrum at the detected position i.
        for i, j in zip(peakPixs, peak_gauss):
            self.selfSpectrumAx.axvline(j,
                                        identify[i] / max(identify) + 0.0003,
                                        identify[i] / max(identify) + 0.2,
                                        color='c')
            if (identify[i] + max(identify) / 2.85 > max(identify)):
                self.selfSpectrumAx.text(j,
                                         identify[i] + max(identify) / 2000,
                                         str(j),
                                         clip_on=False,
                                         picker=self.pickDistance)
            else:
                self.selfSpectrumAx.text(j,
                                         identify[i] + max(identify) / 2.85,
                                         str(j),
                                         ha='center',
                                         va='center',
                                         rotation=90,
                                         clip_on=True,
                                         picker=self.pickDistance)

        self.selfSpectrumAx.plot(identify, 'r--')
        self.selfSpectrumAx.set_xlim(0, len(identify))
        self.selfSpectrumAx.set_ylim(0, )
        # From here on the peak list holds the fitted (float) positions.
        self.selfPeakPixs = np.array(peak_gauss)
        self.selfSpectrumAx.figure.canvas.draw()
        self.gaussFitWidget.close()

    def onFWHMChanged(self, val):
        """Update the comp-image FWHM from a control value.

        ``val`` is divided by 10 to obtain the FWHM, the label is refreshed
        and the Gaussian-fit preview is redrawn with the new value.
        """
        fwhm = val / 10
        self.selfFWHM = fwhm
        # NOTE(review): the label text reads 'FHWM' (sic); kept as is.
        self.FWHMLabel.setText(f'FHWM for comp image = {fwhm}')
        self.selfSpectrumGaussDraw(self.selfSpectrum, self.selfPeakPixs)

    def onGaussFitButtonClicked(self):
        """Run the Gaussian fit on the current calibration spectrum and its
        detected peaks."""
        spectrum, peaks = self.selfSpectrum, self.selfPeakPixs
        self.selfSpectrumGaussFit(spectrum, peaks)

    def neonSpectrumDraw(self):
        """Load the bundled neon arc spectrum file and draw it as the
        standard spectrum; the header returned by the loader is unused."""
        neon_path = './NeonArcSpectrum.fit'
        _header, spectrum = openFitData(neon_path)
        self.standardSpectrumDraw(data=spectrum, arc='Neon')

    def standardSpectrumDraw(self, data, arc, peaks=None):
        """Plot the standard (reference) spectrum and reset the match table.

        Parameters
        ----------
        data : sequence
            ``data[0]`` holds the wavelengths and ``data[1]`` the fluxes.
        arc : str
            Lamp name.  For ``'Neon'`` the built-in neon line list replaces
            ``peaks``.
        peaks : sequence of float, optional
            Wavelengths of the reference peaks.  Each must match an entry of
            ``data[0]`` exactly, otherwise the flux lookup raises IndexError.
            Defaults to no peaks.

        Each peak is marked with a cyan line and a pickable wavelength label.
        ``standardPeakWavelengths``, ``standardSpectrum`` and a fresh
        ``wavelengthPixelList`` (all pixels 0) are stored on the widget.

        Changes over the previous version: the mutable default ``peaks=[]``
        is replaced by ``None``; the upper end of the marker line is now a
        scalar like the lower end (it was a one-element array because of a
        missing ``[0]``); the per-peak flux lookup and the flux maximum are
        computed once instead of repeatedly.
        """
        wavelength = data[0]
        flux = data[1]
        figure = self.standardSpectrumCanvas.figure
        figure.clear()
        self.standardSpectrumAx = figure.add_subplot(111)
        if (arc == 'Neon'):
            peaks = [
                5330.8000, 5400.5620, 5764.4180, 5852.4878, 5944.8342,
                6029.9971, 6074.3377, 6096.1630, 6143.0623, 6163.5939,
                6217.2813, 6266.4950, 6304.7892, 6334.4279, 6382.9914,
                6402.2460, 6506.5279, 6532.8824, 6598.9529, 6717.0428,
                6929.4680, 7032.4127, 7173.9390, 7245.1670, 7438.8990,
                7488.8720, 7535.7750, 8082.4580, 8377.6070
            ]
        elif peaks is None:
            peaks = []

        flux_max = max(flux)
        for peak in peaks:
            peak_flux = flux[np.where(peak == wavelength)][0]
            # Marker line limits are axes fractions, just above the peak.
            self.standardSpectrumAx.axvline(peak,
                                            peak_flux / flux_max + 0.003,
                                            peak_flux / flux_max + 0.2,
                                            color='c')
            if peak_flux + flux_max / 2.85 > flux_max:
                # Tall peak: horizontal label just above the peak top.
                self.standardSpectrumAx.text(peak,
                                             peak_flux + flux_max / 2000,
                                             str(peak),
                                             clip_on=False,
                                             picker=self.pickDistance)
            else:
                self.standardSpectrumAx.text(peak,
                                             peak_flux + flux_max / 2.85,
                                             str(peak),
                                             ha='center',
                                             va='center',
                                             rotation=90,
                                             clip_on=True,
                                             picker=self.pickDistance)

        self.standardSpectrumAx.plot(wavelength, flux, 'r--')
        self.standardSpectrumAx.set_ylim(0, )
        self.standardSpectrumAx.set_xlim(min(wavelength), max(wavelength))
        self.standardPeakWavelengths = np.array(peaks)
        self.standardSpectrum = data
        # Every wavelength starts unmatched (pixel 0).
        self.matchedPeakPixs = np.zeros(
            (self.standardPeakWavelengths.shape[0]))
        matchInfo = np.column_stack(
            (self.standardPeakWavelengths, self.matchedPeakPixs))
        self.wavelengthPixelList = pd.DataFrame(
            matchInfo, columns=['Wavelength', 'Pixel'])
        self.onChangedList()
        self.standardSpectrumAx.figure.canvas.draw()

    def onReidentification(self):
        """Hand the current match list, image data and FWHM to the
        re-identification widget (linear fit) and bring it to the front."""
        widget = self.reidentificationWidget
        widget.setReidentifier(matchList=self.wavelengthPixelList,
                               flux=self.selfData,
                               fitMethod='linear',
                               FWHM=self.selfFWHM)
        widget.show()
        widget.raise_()

    def onChangedList(self):
        """Rebuild the table model from the wavelength/pixel list and attach
        it to the table view."""
        model = tableModel(self.wavelengthPixelList)
        self.wavelengthPixelModel = model
        self.wavelengthPixelTable.setModel(model)

    def onButtonClicked(self, status):
        """Forward ``status`` through the ``bottonSinal`` signal."""
        signal = self.bottonSinal
        signal.emit(status)

    def center(self):
        """Move the window so its frame is centred on the available desktop
        area."""
        frame = self.frameGeometry()
        desktop_centre = QDesktopWidget().availableGeometry().center()
        frame.moveCenter(desktop_centre)
        self.move(frame.topLeft())
Ejemplo n.º 33
0
class Viewer(tk.Frame):
    def __init__(self, parent, collection=None, with_toolbar=True):
        """Build the viewer frame: optional toolbar, matplotlib canvas with
        navigation bar, the spectra list box, and the initial data state.

        ``collection`` is stored through the ``collection`` property; when it
        is truthy the artists and the list are populated immediately.
        """
        tk.Frame.__init__(self, parent)
        if with_toolbar:
            self.create_toolbar()

        # Matplotlib canvas with mouse selection and navigation bar.
        figure = plt.Figure(figsize=(8, 6))
        self.fig = figure
        self.ax = figure.add_subplot(111)
        self.canvas = FigureCanvasTkAgg(figure, master=self)
        self.setupMouseNavigation()
        # ToolBar provides the standard matplotlib features.
        self.navbar = ToolBar(self.canvas, self, self.ax)
        self.setupNavBarExtras(self.navbar)
        self.canvas.get_tk_widget().pack(side=tk.LEFT, fill=tk.BOTH, expand=1)

        # Spectra list.
        self.create_listbox()

        # Toggle options, all off initially.
        for statistic in ('mean', 'median', 'max', 'min', 'std'):
            setattr(self, statistic, False)
        self.spectrum_mode = False
        self.show_flagged = True

        # Data.
        self.collection = collection
        self.head = 0
        self.flag_filepath = os.path.abspath('./flagged_spectra.txt')
        if collection:
            self.update_artists(new_lim=True)
            self.update_list()

        self.pack(fill=tk.BOTH, expand=1)
        self.color = '#000000'

    def returnToSelectMode(self):
        """Leave pan or zoom mode by toggling the active navigation tool
        off; does nothing when neither is active."""
        mode = self.ax.get_navigate_mode()
        if mode == 'PAN':
            # Toggling pan again switches panning off.
            self.navbar.pan()
            return
        if mode == 'ZOOM':
            # Toggling zoom again switches zooming off.
            self.navbar.zoom()

    def setupNavBarExtras(self, navbar):
        """Add the select-mode button and the "Viewing" label to ``navbar``.

        The icon is loaded from ``select.png`` next to this module.

        Fixes over the previous version: the loaded image is converted to a
        Tk-compatible photo image and that object (not the raw image) is
        given to the button, and ``select_button`` now holds the button
        widget instead of the ``None`` returned by ``pack()``.
        """
        # Local import so the block does not depend on a file-level import.
        # assumes the file-level ``Image`` is PIL.Image - TODO confirm
        from PIL import ImageTk

        working_dir = os.path.dirname(os.path.abspath(__file__))
        img = Image.open(os.path.join(working_dir, "select.png"))
        # Keep the photo image on self so it stays referenced for as long as
        # the button exists.
        self.select_icon = ImageTk.PhotoImage(img)

        self.select_button = tk.Button(navbar,
                                       width="24",
                                       height="24",
                                       image=self.select_icon,
                                       command=self.returnToSelectMode)
        self.select_button.pack(side=tk.LEFT, anchor=tk.W)

        self.dirLbl = tk.Label(navbar, text="Viewing: None")
        self.dirLbl.pack(side=tk.LEFT, anchor=tk.W)

    def plotConfig(self):
        """Open the plot configuration dialog pre-filled with the current
        title, axis labels and limits, and apply the result if confirmed."""
        axes = self.ax
        config = PlotConfigDialog(self,
                                  title=axes.get_title(),
                                  xlabel=axes.get_xlabel(),
                                  ylabel=axes.get_ylabel(),
                                  xlim=axes.get_xlim(),
                                  ylim=axes.get_ylim())
        if not config.applied:
            return

        print(config.title)
        print(config.xlim)
        axes.set_title(config.title)
        axes.set_xlabel(config.xlabel)
        axes.set_ylabel(config.ylabel)
        axes.set_xlim(*config.xlim)
        axes.set_ylim(*config.ylim)
        self.canvas.draw()

    def rectangleStartEvent(self, event):
        """Begin a rectangle selection: forget any previous rectangle and
        remember the press event as the anchor corner."""
        self._rect, self._rect_start = None, event

    def rectangleMoveEvent(self, event):
        """Redraw the dashed selection rectangle from the anchor corner to
        the current mouse position.

        Moves without data coordinates (mouse outside the axes) are ignored.
        """
        anchor = self._rect_start
        try:
            width = event.xdata - anchor.xdata
            height = event.ydata - anchor.ydata
        except TypeError:
            # xdata/ydata are None when the mouse is outside the axes.
            return

        previous = self._rect
        if previous is not None:
            previous.remove()

        outline = Rectangle((anchor.xdata, anchor.ydata),
                            width,
                            height,
                            color='k',
                            ls='--',
                            lw=1,
                            fill=False)
        self._rect = outline
        self.ax.add_patch(outline)
        self.ax.draw_artist(outline)

    def rectangleEndEvent(self, event):
        """Finish a rectangle selection and select the spectra inside it.

        If no rectangle was dragged (a plain click), a small box of +/- 10 in
        x and +/- 1% of the y-range around the click is used instead.  Every
        spectrum (column of ``collection.data``) that has at least one value
        strictly inside the box is passed to ``update_selected`` and selected
        in the list box, unless it is flagged and flagged spectra are hidden.

        Changes over the previous version: a release outside the axes (no
        data coordinates) is ignored instead of raising TypeError; the bare
        ``except:`` is narrowed to ``Exception``; the unused ``ylim`` local
        is gone; list positions are looked up in a dict instead of a linear
        search per spectrum (a spectrum missing from ``_spectra`` now raises
        KeyError rather than ValueError).
        """
        dragged = self._rect is not None
        if dragged:
            self._rect.remove()

        if event.xdata is None or event.ydata is None:
            # Released outside the axes: there is no box to evaluate.
            return

        if not dragged:
            # Make a small, fake rectangle around the click position.
            class FakeEvent(object):
                def __init__(self, x, y):
                    self.xdata, self.ydata = x, y

            dy = (self.ax.get_ylim()[1] - self.ax.get_ylim()[0]) / 100.
            self._rect_start = FakeEvent(event.xdata - 10, event.ydata + dy)
            event = FakeEvent(event.xdata + 10, event.ydata - dy)

        if self.collection is None:
            return

        x0 = min(self._rect_start.xdata, event.xdata)
        x1 = max(self._rect_start.xdata, event.xdata)
        y0 = min(self._rect_start.ydata, event.ydata)
        y1 = max(self._rect_start.ydata, event.ydata)

        data = self.collection.data
        try:
            # Label slicing is the quick path and works when it is valid for
            # the index.
            x_data = data.loc[x0:x1]
        except Exception:
            # Fall back to a boolean mask over the index.
            in_xrange = (data.index >= x0) & (data.index <= x1)
            x_data = data.iloc[in_xrange]

        is_in_box = ((x_data > y0) & (x_data < y1)).any()
        highlighted = is_in_box.index[is_in_box].tolist()

        self.update_selected(highlighted)

        # Map each spectrum key to its list position once.
        positions = {
            key: pos
            for pos, key in enumerate(self.collection._spectra.keys())
        }
        flags = self.collection.flags
        for highlight in highlighted:
            if highlight not in flags or self.show_flagged:
                self.listbox.selection_set(positions[highlight])

    def setupMouseNavigation(self):
        """Wire mouse press/move/release on the canvas to the selection
        handlers of the active select mode (currently only 'rectangle').

        The handlers only act while no matplotlib navigation tool is active.
        The axes background is cached on press and restored and blitted
        around every redraw of the selection.
        """
        self.clicked = False
        self.select_mode = 'rectangle'
        self._bg_cache = None

        # Per-mode dispatch tables.
        start_handlers = {'rectangle': self.rectangleStartEvent}
        move_handlers = {'rectangle': self.rectangleMoveEvent}
        end_handlers = {'rectangle': self.rectangleEndEvent}

        def selecting():
            # True when neither pan nor zoom is active.
            return self.ax.get_navigate_mode() is None

        def onMouseDown(event):
            if not selecting():
                return
            self._bg_cache = self.canvas.copy_from_bbox(self.ax.bbox)
            self.clicked = True
            start_handlers[self.select_mode](event)

        def onMouseUp(event):
            if not selecting():
                return
            self.canvas.restore_region(self._bg_cache)
            self.canvas.blit(self.ax.bbox)
            self.clicked = False
            end_handlers[self.select_mode](event)

        def onMouseMove(event):
            if not selecting() or not self.clicked:
                return
            self.canvas.restore_region(self._bg_cache)
            move_handlers[self.select_mode](event)
            self.canvas.blit(self.ax.bbox)

        for name, handler in (('button_press_event', onMouseDown),
                              ('button_release_event', onMouseUp),
                              ('motion_notify_event', onMouseMove)):
            self.canvas.mpl_connect(name, handler)

    @property
    def head(self):
        """Index of the current spectrum within the collection."""
        return self._head

    @head.setter
    def head(self, value):
        """Set the current index, wrapped into the collection's range.

        The very first assignment always initialises the index to 0 (as
        before).  When there is no collection, or it is empty, the index is
        reset to 0 instead of raising TypeError / ZeroDivisionError from
        ``value % len(self.collection)``.
        """
        collection = getattr(self, 'collection', None)
        size = len(collection) if collection is not None else 0
        if not hasattr(self, '_head') or size == 0:
            self._head = 0
        else:
            self._head = value % size

    def set_head(self, value):
        """Set the current spectrum index and refresh the view.

        An iterable ``value`` is reduced to its first element, or 0 when it
        is empty.  The plot is updated only in spectrum mode; the selection
        is always refreshed.
        """
        if isinstance(value, Iterable):
            value = value[0] if len(value) > 0 else 0
        self.head = value
        if self.spectrum_mode:
            self.update()
        self.update_selected()

    @property
    def collection(self):
        """The collection being viewed, or None."""
        return self._collection

    @collection.setter
    def collection(self, value):
        """Store ``value`` as the viewed collection.

        A single ``Spectrum`` is wrapped in a new one-element ``Collection``
        named after that spectrum; a ``Collection`` is stored as is; anything
        else clears the collection.

        Fixes over the previous version: the branches are now exclusive (the
        second ``if``/``else`` used to overwrite a wrapped Spectrum with
        None), and the new collection takes the instance's ``value.name``
        rather than the class attribute ``Spectrum.name``.
        """
        if isinstance(value, Spectrum):
            # Create a new collection around the single spectrum.
            self._collection = Collection(name=value.name, spectra=[value])
        elif isinstance(value, Collection):
            self._collection = value
        else:
            self._collection = None

    def move_selected_to_top(self):
        """Move the selected list entries to the top and keep them selected.

        Fixes over the previous version: ``selection_set(first, last)`` is
        inclusive, so the last index is now ``len(keys) - 1`` (one extra row
        used to be selected), and an empty selection is a no-op instead of
        selecting the first row.
        """
        selected = self.listbox.curselection()
        if not selected:
            return
        # NOTE(review): names are looked up by list position in
        # collection.spectra - presumably the list order matches the
        # collection order; verify once entries have been reordered.
        keys = [self.collection.spectra[s].name for s in selected]
        # Delete from the bottom up so the remaining indices stay valid.
        for s in selected[::-1]:
            self.listbox.delete(s)
        self.listbox.insert(0, *keys)
        self.listbox.selection_set(0, len(keys) - 1)

    def unselect_all(self):
        """Clear the whole list selection and refresh the selected view."""
        entries = self.listbox
        entries.selection_clear(0, tk.END)
        self.update_selected()

    def select_all(self):
        """Select every list entry and refresh the selected view."""
        entries = self.listbox
        entries.selection_set(0, tk.END)
        self.update_selected()

    def invert_selection(self):
        """Toggle the selection state of every list entry, then refresh the
        selected view."""
        box = self.listbox
        for row in range(box.size()):
            if box.selection_includes(row):
                toggle = box.selection_clear
            else:
                toggle = box.selection_set
            toggle(row)
        self.update_selected()

    def change_color(self):
        """Let the user pick a colour and assign it to the selected spectra.

        When the dialog is applied, the chosen colour becomes the current
        colour, is shown on the colour button, is recorded in ``self.colors``
        for every selected spectrum name, and the view is updated.
        """
        cpicker = ColorPickerDialog(self)
        if not cpicker.applied:
            return

        self.color = cpicker.color
        self.color_pick.config(bg=self.color)

        # Resolve all names first, then update the chosen colours.
        chosen_rows = self.listbox.curselection()
        chosen_names = [self.collection.spectra[row].name
                        for row in chosen_rows]
        for name in chosen_names:
            self.colors[name] = self.color
        self.update()

    def select_by_name(self):
        """Select exactly those list entries whose text contains the
        substring typed in the name filter, then refresh the selected view."""
        needle = self.name_filter.get()
        box = self.listbox
        for row in range(box.size()):
            matches = needle in box.get(row)
            if matches:
                box.selection_set(row)
            else:
                box.selection_clear(row)
        self.update_selected()

    def create_listbox(self):
        """Build and pack the side panel: filter row, spectrum list, tools.

        Creates ``self._sbframe`` holding
          * a header row with the name filter entry (``self.name_filter``),
            a "Select" button and the ``self.sblabel`` counter label,
          * ``self.listbox`` with its ``self.scrollbar``,
          * ``self.list_tools`` with the selection buttons and the color
            swatch button ``self.color_pick``.
        A listbox selection change calls ``self.set_head`` with the current
        selection.
        """
        panel = tk.Frame(self)
        self._sbframe = panel

        # Header row: "Name:" filter entry, Select button, shown-count label.
        header = tk.Frame(panel)
        header.pack(side=tk.TOP, anchor=tk.N, fill=tk.X)
        tk.Label(header, text="Name:").pack(side=tk.LEFT, anchor=tk.W)
        self.name_filter = tk.Entry(header, width=14)
        self.name_filter.pack(side=tk.LEFT, anchor=tk.W)
        select_button = tk.Button(header,
                                  text="Select",
                                  command=lambda: self.select_by_name())
        select_button.pack(side=tk.LEFT, anchor=tk.W)
        self.sblabel = tk.Label(header, text="Showing: 0")
        self.sblabel.pack(side=tk.RIGHT)

        # Listbox and scrollbar are wired to each other before being packed.
        self.scrollbar = tk.Scrollbar(panel)
        self.listbox = tk.Listbox(panel,
                                  yscrollcommand=self.scrollbar.set,
                                  selectmode=tk.EXTENDED,
                                  width=30)
        self.scrollbar.config(command=self.listbox.yview)

        # Column of selection tools, stacked top to bottom.
        self.list_tools = tk.Frame(panel)
        tool_buttons = (
            ("To Top", lambda: self.move_selected_to_top()),
            ("Select All", lambda: self.select_all()),
            ("Clear", lambda: self.unselect_all()),
            ("Invert", lambda: self.invert_selection()),
        )
        for caption, action in tool_buttons:
            tk.Button(self.list_tools, text=caption,
                      command=action).pack(side=tk.TOP,
                                           anchor=tk.NW,
                                           fill=tk.X)

        # Color row: label plus a button whose background shows the color.
        self.color_field = tk.Frame(self.list_tools)
        tk.Label(self.color_field, text="Color:").pack(side=tk.LEFT)
        self.color_pick = tk.Button(self.color_field,
                                    text="",
                                    command=lambda: self.change_color(),
                                    bg='#000000')
        self.color_pick.pack(side=tk.RIGHT,
                             anchor=tk.NW,
                             fill=tk.X,
                             expand=True)
        self.color_field.pack(side=tk.TOP, anchor=tk.NW, fill=tk.X)

        # Pack right-to-left: tools, then scrollbar, then the listbox.
        self.list_tools.pack(side=tk.RIGHT, anchor=tk.NW)
        self.scrollbar.pack(side=tk.RIGHT, anchor=tk.E, fill=tk.Y)
        self.listbox.pack(side=tk.RIGHT, anchor=tk.E, fill=tk.Y)
        self.listbox.bind(
            '<<ListboxSelect>>',
            lambda _event: self.set_head(self.listbox.curselection()))
        panel.pack(side=tk.RIGHT, anchor=tk.E, fill=tk.Y)

    def create_toolbar(self):
        """Build the button bar and pack it along the top of the widget.

        One button is created per entry of the table below, left to right in
        that order; each button expands horizontally.  The callbacks look the
        target method up at click time.
        """
        bar = tk.Frame(self)
        self.toolbar = bar
        actions = (
            ('Read', lambda: self.read_dir()),
            ('Mode', lambda: self.toggle_mode()),
            ("Plot Config", lambda: self.plotConfig()),
            ('Show/Hide Flagged', lambda: self.toggle_show_flagged()),
            ('Flag/Unflag', lambda: self.toggle_flag()),
            ('Unflag all', lambda: self.unflag_all()),
            # Saving always goes through the "save as" dialog.
            ('Save Flags', lambda: self.save_flag_as()),
            ('Stitch', lambda: self.stitch()),
            ('Jump_Correct', lambda: self.jump_correct()),
            ('mean', lambda: self.toggle_mean()),
            ('median', lambda: self.toggle_median()),
            ('max', lambda: self.toggle_max()),
            ('min', lambda: self.toggle_min()),
            ('std', lambda: self.toggle_std()),
        )
        for caption, callback in actions:
            button = tk.Button(bar, text=caption, command=callback)
            button.pack(side=tk.LEFT, fill=tk.X, expand=1)
        bar.pack(side=tk.TOP, fill=tk.X)

    def updateTitle(self):
        """Set the axes title to the current value of ``self._title`` and redraw."""
        self.ax.set_title(self._title.get())
        self.canvas.draw()

    def set_collection(self, collection):
        """Make ``collection`` the displayed data set and redraw everything.

        ``update_artists`` is asked for new limits only when no collection
        was set before this call.
        """
        had_no_collection = self.collection is None
        self.collection = collection
        self.update_artists(new_lim=had_no_collection)
        self.update()
        self.update_list()

    def read_dir(self):
        """Let the user pick a spectrum file and load its whole directory.

        Shows an "open file" dialog.  The directory containing the chosen
        file is loaded into a new ``Collection`` named "collection", which
        becomes the displayed collection, and the directory label is
        updated.  Nothing happens if the dialog is cancelled or the chosen
        path has no directory part.
        """
        chosen = filedialog.askopenfilename(filetypes=(
            ("Supported types", "*.asd *.sed *.sig *.pico"),
            ("All files", "*"),
        ))
        # A cancelled dialog returns an empty value ('' or, with some Tk
        # builds, an empty tuple).  Test for it explicitly instead of relying
        # on a bare ``except`` that would also hide unrelated errors.
        if not chosen:
            return
        directory = os.path.dirname(chosen)
        if not directory:
            return
        c = Collection(name="collection", directory=directory)
        self.set_collection(c)
        self.dirLbl.config(text="Viewing: " + directory)

    def reset_stats(self):
        """Remove every plotted statistic line and switch its flag off.

        For each of mean, median, max, min and std: if ``self.<stat>_line``
        is set, the line is removed from the plot, the attribute is reset to
        None and ``self.<stat>`` is set to False.  Statistics without a line
        are left untouched.
        """
        for stat in ('mean', 'median', 'max', 'min', 'std'):
            line_attr = stat + '_line'
            line = getattr(self, line_attr)
            if not line:
                continue
            line.remove()
            setattr(self, line_attr, None)
            setattr(self, stat, False)

    def toggle_mode(self):
        """Switch between spectrum mode and collection mode, then redraw."""
        self.spectrum_mode = not self.spectrum_mode
        self.update()

    def toggle_show_flagged(self):
        """Flip whether flagged spectra are shown, then redraw."""
        self.show_flagged = not self.show_flagged
        self.update()

    def unflag_all(self):
        """Remove every flag from the collection and refresh plot and list.

        Does nothing when no collection has been loaded yet.
        """
        if self.collection is None:
            return
        # new flags -> new statistics
        self.reset_stats()

        # Iterate over a copy: unflag() presumably mutates collection.flags.
        for spectrum in list(self.collection.flags):
            self.collection.unflag(spectrum)
        self.update()
        self.update_list()

    def toggle_flag(self):
        """Flip the flag of every selected spectrum and redraw.

        A selected entry that is currently flagged is unflagged and shown in
        black in the listbox; an unflagged one is flagged and shown in red.
        Spectra are identified by the text of their listbox row.
        """
        # new flags -> new statistics
        self.reset_stats()

        listbox = self.listbox
        for row in listbox.curselection():
            spectrum = listbox.get(row)
            if spectrum in self.collection.flags:
                self.collection.unflag(spectrum)
                listbox.itemconfigure(row, foreground='black')
            else:
                self.collection.flag(spectrum)
                listbox.itemconfigure(row, foreground='red')
        # update figure
        self.update()

    def save_flag(self):
        ''' write every flagged spectrum, one per line, to self.flag_filepath'''
        with open(self.flag_filepath, 'w') as out:
            for spectrum in self.collection.flags:
                out.write(str(spectrum) + '\n')

    def save_flag_as(self):
        ''' ask for a file name, store it in self.flag_filepath and call save_flag()

        A name without an extension gets '.txt' appended.  If the dialog is
        cancelled nothing is stored and nothing is written.
        '''
        flag_filepath = filedialog.asksaveasfilename()
        # A cancelled dialog returns an empty value.  Without this check the
        # empty name would be turned into '.txt' and a file of that name
        # would be written.
        if not flag_filepath:
            return
        if os.path.splitext(flag_filepath)[1] == '':
            flag_filepath = flag_filepath + '.txt'
        self.flag_filepath = flag_filepath
        self.save_flag()

    def update_list(self):
        """Rebuild the listbox from the collection; flagged names are red."""
        listbox = self.listbox
        listbox.delete(0, tk.END)
        for row, spectrum in enumerate(self.collection.spectra):
            name = spectrum.name
            listbox.insert(tk.END, name)
            if name not in self.collection.flags:
                continue
            listbox.itemconfigure(row, foreground='red')
        self.update_selected()

    def ask_for_draw(self):
        """Redraw the canvas unless it was redrawn less than 0.5 s ago.

        ``self.last_draw`` holds the time of the last redraw made through
        this method and is updated whenever a redraw happens.
        """
        # debounce canvas updates
        now = datetime.now()
        if (now - self.last_draw).total_seconds() > 0.5:
            self.canvas.draw()
            self.last_draw = now

    def update_artists(self, new_lim=False):
        """Re-plot the data from scratch and rebuild the artist bookkeeping.

        Clears the axes, plots either the current selection (spectrum mode)
        or the whole collection, then rebuilds ``self.artist_dict`` (name ->
        plotted line) and resets ``self.colors`` to 'black' for every
        spectrum.  Finally the navbar "home" view is set to the new axis
        limits, the canvas is redrawn and the counter label is updated.
        Does nothing when no collection is set.

        Parameters
        ----------
        new_lim : bool
            Accepted for compatibility with existing callers; currently not
            used.  NOTE(review): the previous code read the axis limits when
            this was False but never restored them, so the view is always
            reset - confirm whether the limits were meant to be preserved.
        """
        if self.collection is None:
            return
        # The axes are cleared below, so any statistic line drawn on them is
        # gone; drop the handles.
        self.mean_line = None
        self.median_line = None
        self.max_line = None
        self.min_line = None
        self.std_line = None

        self.ax.clear()

        # Style used for flagged spectra: red when they are to be shown,
        # otherwise the ' ' style (presumably draws nothing - confirm against
        # Collection.plot).
        flag_style = 'r' if self.show_flagged else ' '

        if self.spectrum_mode:
            idx = self.listbox.curselection()
            if len(idx) == 0:
                idx = [self.head]
            spectra = [self.collection.spectra[i] for i in idx]
            flags = [s.name in self.collection.flags for s in spectra]
            Collection(name='selection', spectra=spectra).plot(
                ax=self.ax,
                style=list(np.where(flags, flag_style, self.color)),
                picker=1)
            self.ax.set_title('selection')
        else:
            flags = [
                s.name in self.collection.flags
                for s in self.collection.spectra
            ]
            self.collection.plot(ax=self.ax,
                                 style=list(np.where(flags, flag_style, 'k')),
                                 picker=1)

        # NOTE(review): in spectrum mode only the selected spectra are
        # plotted, yet the names of *all* spectra are paired with the lines
        # below, so the mapping may be wrong or incomplete there - confirm.
        keys = [s.name for s in self.collection.spectra]
        artists = self.ax.lines
        self.artist_dict = dict(zip(keys, artists))
        self.colors = dict.fromkeys(keys, 'black')
        self.ax.legend().remove()
        self.navbar.setHome(self.ax.get_xlim(), self.ax.get_ylim())
        self.canvas.draw()
        self.sblabel.config(text="Showing: {}".format(len(artists)))

    def update_selected(self, to_add=None):
        """Show selected spectra dashed and all others solid, then redraw.

        Parameters
        ----------
        to_add : iterable of str, optional
            If given and non-empty, only the lines of these names are
            switched to dashed; every other line keeps its current style.
            Otherwise the line style of every known artist is set from the
            current listbox selection.

        Does nothing when no collection is set.
        """
        if self.collection is None:
            return

        if to_add:
            for key in to_add:
                self.artist_dict[key].set_linestyle('--')
        else:
            # Resolve the selection through the listbox text: rows can be
            # moved to the top of the listbox, after which a row index is no
            # longer a valid index into collection.spectra.
            listbox = self.listbox
            selected_keys = {listbox.get(row) for row in listbox.curselection()}
            # Walk the artists that actually exist, so a name without a
            # plotted line cannot raise KeyError.
            for key, artist in self.artist_dict.items():
                artist.set_linestyle('--' if key in selected_keys else '-')
        self.canvas.draw()

    def update(self):
        """ Update the plot

        In spectrum mode the axes are cleared and the selected spectra (or
        the one at ``self.head`` when nothing is selected) are plotted anew.
        In collection mode the existing artists are only restyled: flagged
        spectra are red or hidden depending on ``self.show_flagged``, all
        others are visible in their color from ``self.colors``; the counter
        label shows how many lines are displayed.  Afterwards the y label is
        set to the collection's measure type, each existing statistic line
        is shown or hidden according to its flag, and the canvas is redrawn.
        Does nothing when no collection is set.

        NOTE(review): ``self`` is used as a tkinter master elsewhere in this
        class, so this method presumably shadows tkinter's own ``update()``
        - confirm that is intended.
        """
        if self.collection is None:
            return

        if self.spectrum_mode:
            self.ax.clear()
            idx = self.listbox.curselection()
            if len(idx) == 0:
                idx = [self.head]
            spectra = [self.collection.spectra[i] for i in idx]
            flags = [s.name in self.collection.flags for s in spectra]
            flag_style = 'r' if self.show_flagged else ' '
            Collection(name='selection', spectra=spectra).plot(
                ax=self.ax,
                style=list(np.where(flags, flag_style, 'k')),
                picker=1)
            self.ax.set_title('selection')
        else:
            # red curves for flagged spectra
            for spectrum in self.collection.spectra:
                key = spectrum.name
                artist = self.artist_dict[key]
                if key in self.collection.flags:
                    if self.show_flagged:
                        artist.set_visible(True)
                        artist.set_color('red')
                    else:
                        artist.set_visible(False)
                else:
                    artist.set_color(self.colors[key])
                    artist.set_visible(True)

            shown = len(self.artist_dict)
            if not self.show_flagged:
                shown -= len(self.collection.flags)
            self.sblabel.config(text="Showing: {}".format(shown))

        self.ax.set_ylabel(self.collection.measure_type)
        # toggle appearance of statistics
        for stat in ('mean', 'median', 'max', 'min', 'std'):
            line = getattr(self, stat + '_line')
            if line is not None:
                line.set_visible(getattr(self, stat))
        self.canvas.draw()

    def next_spectrum(self):
        """In spectrum mode, advance ``self.head`` by one (wrapping) and redraw."""
        if self.spectrum_mode:
            following = self.head + 1
            self.head = following % len(self.collection)
            self.update()

    def stitch(self):
        ''' 
        Stitch the loaded collection and re-plot it.

        Does nothing when no collection has been loaded yet.

        Known Bugs
        ----------
        Can't stitch one spectrum and plot the collection
        '''
        if self.collection is None:
            return
        self.collection.stitch()
        self.update_artists()

    def jump_correct(self):
        ''' 
        Jump-correct the loaded collection and re-plot it.

        Does nothing when no collection has been loaded yet.

        Known Bugs
        ----------
        Only performs jump correction on 1000 and 1800 wvls and 1 reference
        '''
        if self.collection is None:
            return
        self.collection.jump_correct([1000, 1800], 1)
        self.update_artists()

    def toggle_mean(self):
        """Show or hide the mean of the collection.

        The mean line is plotted (blue, width 3) the first time it is
        switched on and kept in ``self.mean_line``; later calls only flip
        ``self.mean`` and let ``update()`` show or hide the line.
        Does nothing when no collection has been loaded yet.
        """
        if self.collection is None:
            # Nothing to compute; leave the flag as it is.
            return
        self.mean = not self.mean
        if self.mean and not self.mean_line:
            self.collection.mean().plot(ax=self.ax,
                                        c='b',
                                        label=self.collection.name + '_mean',
                                        lw=3)
            # The line just plotted is taken to be the newest one on the axes.
            self.mean_line = self.ax.lines[-1]
        self.update()

    def toggle_median(self):
        """Show or hide the median of the collection.

        The median line is plotted (green, width 3) the first time it is
        switched on and kept in ``self.median_line``; later calls only flip
        ``self.median`` and let ``update()`` show or hide the line.
        Does nothing when no collection has been loaded yet.
        """
        if self.collection is None:
            # Nothing to compute; leave the flag as it is.
            return
        self.median = not self.median
        if self.median and not self.median_line:
            self.collection.median().plot(ax=self.ax,
                                          c='g',
                                          label=self.collection.name +
                                          '_median',
                                          lw=3)
            # The line just plotted is taken to be the newest one on the axes.
            self.median_line = self.ax.lines[-1]
        self.update()

    def toggle_max(self):
        """Show or hide the maximum of the collection.

        The max line is plotted (yellow, width 3) the first time it is
        switched on and kept in ``self.max_line``; later calls only flip
        ``self.max`` and let ``update()`` show or hide the line.
        Does nothing when no collection has been loaded yet.
        """
        if self.collection is None:
            # Nothing to compute; leave the flag as it is.
            return
        self.max = not self.max
        if self.max and not self.max_line:
            self.collection.max().plot(ax=self.ax,
                                       c='y',
                                       label=self.collection.name + '_max',
                                       lw=3)
            # The line just plotted is taken to be the newest one on the axes.
            self.max_line = self.ax.lines[-1]
        self.update()

    def toggle_min(self):
        """Show or hide the minimum of the collection.

        The min line is plotted (magenta, width 3) the first time it is
        switched on and kept in ``self.min_line``; later calls only flip
        ``self.min`` and let ``update()`` show or hide the line.
        Does nothing when no collection has been loaded yet.
        """
        if self.collection is None:
            # Nothing to compute; leave the flag as it is.
            return
        self.min = not self.min
        if self.min and not self.min_line:
            self.collection.min().plot(ax=self.ax,
                                       c='m',
                                       label=self.collection.name + '_min',
                                       lw=3)
            # The line just plotted is taken to be the newest one on the axes.
            self.min_line = self.ax.lines[-1]
        self.update()

    def toggle_std(self):
        """Show or hide the standard deviation of the collection.

        The std line is plotted (cyan, width 3) the first time it is
        switched on and kept in ``self.std_line``; later calls only flip
        ``self.std`` and let ``update()`` show or hide the line.
        Does nothing when no collection has been loaded yet.
        """
        if self.collection is None:
            # Nothing to compute; leave the flag as it is.
            return
        self.std = not self.std
        if self.std and not self.std_line:
            self.collection.std().plot(ax=self.ax,
                                       c='c',
                                       label=self.collection.name + '_std',
                                       lw=3)
            # The line just plotted is taken to be the newest one on the axes.
            self.std_line = self.ax.lines[-1]
        self.update()
Ejemplo n.º 34
0
class cropWidget(QWidget):
    """Widget that shows an image and lets the user drag out a crop box.

    The image is read from the primary HDU of the file given by
    ``filename`` and drawn with ``zimshow``.  Dragging with the mouse inside
    the axes draws a translucent rectangle; when the button is released
    inside the axes the integer bounds of the rectangle are stored in
    ``self.cropInfo``, ``cropDoneSignal`` is emitted with it and the
    ``cropCheckWidget`` window is shown.  Releasing the button outside the
    axes cancels the drag.
    """
    cropDoneSignal = pyqtSignal(cropInfo)

    def __init__(self, currentFileLocation = ''):
        """Create the widget; an empty ``currentFileLocation`` shows no image."""
        super().__init__()
        self.filename = currentFileLocation
        self.cropInfo = cropInfo()
        # True between a button press inside the axes and the end of the drag.
        self.isPressed = False
        self.cropCheckWidget = cropCheckWidget(self.cropInfo)
        self.initUI()

    def initUI(self):
        """Create the matplotlib canvas, hook up mouse events, draw the image."""
        self.hbox = QHBoxLayout()

        self.fig = plt.Figure()
        self.canvas = FigureCanvas(self.fig)
        self.ax = self.fig.add_subplot(111)

        self.canvas.mpl_connect("button_press_event", self.on_press)
        self.canvas.mpl_connect("motion_notify_event", self.on_move)
        self.canvas.mpl_connect("button_release_event", self.on_release)

        self.hbox.addWidget(self.canvas)
        self.setLayout(self.hbox)

        if self.filename != '':
            self._show_image()
        self.canvas.draw()

    def setFileName(self, fileName):
        """Switch to the image stored in ``fileName`` and redraw."""
        self.filename = fileName
        self._show_image()
        self.canvas.draw()

    def _show_image(self):
        """Load the data of ``self.filename`` and draw it on the axes."""
        # NOTE(review): the object returned by fits.open() is never closed
        # here - confirm whether it should be.
        self.data = fits.open(Path(self.filename))[0].data
        zimshow(self.ax, self.data)

    def _cancel_drag(self):
        """Drop the rectangle of the drag in progress and redraw."""
        self.rect.remove()
        self.isPressed = False
        self.ax.figure.canvas.draw()

    def on_press(self, event):
        """Start a drag: add a rectangle and remember the press position."""
        if event.inaxes != self.ax:
            return
        if self.isPressed:
            # A previous drag never finished (e.g. a second button was
            # pressed); remove its rectangle so it is not left on the canvas.
            self._cancel_drag()
        self.rect = Rectangle((0, 0), 1, 1, alpha=0.5)
        self.ax.add_patch(self.rect)
        self.x0 = event.xdata
        self.y0 = event.ydata
        self.isPressed = True

    def on_move(self, event):
        """While dragging, stretch the rectangle to the mouse position."""
        if event.inaxes != self.ax:
            return
        if not self.isPressed:
            return

        self.x1 = event.xdata
        self.y1 = event.ydata
        self.rect.set_width(self.x1 - self.x0)
        self.rect.set_height(self.y1 - self.y0)
        self.rect.set_xy((self.x0, self.y0))
        self.ax.figure.canvas.draw()

    def on_release(self, event):
        """Finish the drag and publish the selected crop region."""
        if not self.isPressed:
            return
        if event.inaxes != self.ax:
            # Released outside the image.  Without ending the drag here,
            # isPressed would stay True and the rectangle would stay on the
            # canvas and keep following the mouse.
            self._cancel_drag()
            return

        x = int(self.rect.get_x())
        y = int(self.rect.get_y())
        width = int(self.rect.get_width())
        height = int(self.rect.get_height())

        # Width/height are negative when dragging left/up; order the bounds.
        x0, x1 = sorted((x, x + width))
        y0, y1 = sorted((y, y + height))

        self.cropInfo.x0 = x0
        self.cropInfo.x1 = x1
        self.cropInfo.y0 = y0
        self.cropInfo.y1 = y1
        self.cropInfo.filename = self.filename
        self.cropDoneSignal.emit(self.cropInfo)
        self.rect.remove()
        self.ax.figure.canvas.draw()
        self.isPressed = False
        self.cropCheckWidget.setCropInfo(self.cropInfo)
        self.cropCheckWidget.show()
        self.cropCheckWidget.raise_()
Ejemplo n.º 35
0
class ItemArtist:

  def __init__(self, position, state):
    """Create the lines, label and boxes for one position of ``state``.

    ``position`` is looked up in ``state.positions``; the entries of
    ``state.tops``, ``state.bottoms`` and ``state.edges`` at that index are
    stored negated and drawn as horizontal lines from 0 to ``width``
    (top and bottom in blue, edge in green).  A centered text label with
    the position is added to the current axes.
    """
    self.position = position

    slot = state.positions.index(position)

    self.top = -state.tops[slot]
    (self.top_line,) = pylab.plot([0, width], [self.top, self.top], c='b')

    self.bottom = -state.bottoms[slot]
    (self.bottom_line,) = pylab.plot([0, width],
                                     [self.bottom, self.bottom], c='b')

    self.edge = -state.edges[slot]
    (self.edge_line,) = pylab.plot([0, width], [self.edge, self.edge], c='g')

    middle = (self.top + self.bottom)/2
    self.label = Text(width/2, middle, str(position),
        va='center', ha='center')

    self.axes = pylab.gca()
    self.axes.add_artist(self.label)

    # Boxes are created on demand by _check_boxes.
    self.src_box = None
    self.exp_box = None
    self._check_boxes(state)


  def _check_boxes(self, state):

    if self.position == state.src:
      if self.src_box == None:
        self.src_box = Rectangle((0, self.bottom), width,
          self.top - self.bottom, fill=True, ec=None, fc='0.7')
        self.axes.add_patch(self.src_box)
      else:
        self.src_box.set_y(self.bottom)
        self.src_box.set_height(self.top - self.bottom)

    elif self.position == state.exp1:
      if state.exp1 < state.src:
        gap_bottom = self.top - state.exp1_gap
      else:
        gap_bottom = self.bottom

      if self.exp_box == None:
        self.exp_box = Rectangle((0,gap_bottom), width,
          state.exp1_gap, fill=True, ec=None, fc='0.7')
        self.axes.add_patch(self.exp_box)
      else:
        self.exp_box.set_y(gap_bottom)
        self.exp_box.set_height(state.exp1_gap)

    elif self.position == state.exp2:
      if state.exp2 < state.src:
        gap_bottom = self.top - state.exp2_gap
      else:
        gap_bottom = self.bottom

      if self.exp_box == None:
        self.exp_box = Rectangle((0,gap_bottom), width, state.exp2_gap,
          fill=True, ec=None, fc='0.7')
        self.axes.add_patch(self.exp_box)
      else:
        self.exp_box.set_y(gap_bottom)
        self.exp_box.set_height(state.exp2_gap)
    else:
      if self.src_box != None:
        self.src_box.remove()
        self.src_box = None
      if self.exp_box != None:
        self.exp_box.remove()
        self.exp_box = None


  def inState(self, state):
    return self.position in state.positions

  def update(self, position, state):
    moved = False

    if position != self.position:
      self.position = position
      self.label.set_text(str(position))

    indx = state.positions.index(self.position)

    old_top = self.top
    self.top = -state.tops[indx]
    if old_top != self.top:
      self.top_line.set_ydata(2*[self.top])
      moved = True

    old_bottom = self.bottom
    self.bottom = -state.bottoms[indx]
    if old_bottom != self.bottom:
      self.bottom_line.set_ydata(2*[self.bottom])
      moved = True

    old_edge = self.edge
    self.edge = -state.edges[indx]
    if old_edge != self.edge:
      self.edge_line.set_ydata(2*[self.edge])
    
    if moved:
      # adjust label, blank spot, etc.
      self.label.set_y((self.top + self.bottom)/2)
      self._check_boxes(state)

  def remove(self):
    self.edge_line.remove()
    self.top_line.remove()
    self.bottom_line.remove()
    self.label.remove()

    if self.src_box != None:
      self.src_box.remove()
    if self.exp_box != None:
      self.exp_box.remove()