Exemplo n.º 1
0
class MainFrame(wx.Frame):
    """Contains the main GUI and button boxes.

    The frame is split horizontally: the top pane (a ``MatplotPanel``) holds
    the matplotlib canvas showing the current frame, the body-part radio box
    and the marker-size controls; the bottom pane holds the Load Frames, Help,
    Next Frame, Save and Quit buttons.

    Workflow: ``browseDir`` loads the ``*.png`` frames and the project config,
    right clicks (``onClick``) place labels, ``nextImage`` copies the labels of
    the current frame into ``self.dataFrame`` and advances, and ``save`` writes
    ``CollectedData_<scorer>.csv`` / ``.h5`` into the frames directory.
    """

    def __init__(self, parent, config):
        """Create the window layout and the (mostly disabled) control buttons.

        Args:
            parent: wx parent window passed to ``wx.Frame``.
            config: path of the project configuration YAML file; it is read in
                ``browseDir`` when frames are loaded.
        """
        wx.Frame.__init__(self, parent, title="DeepLabCut2.0 - Labeling ToolBox", size=(1600, 980))

        # SplitterWindow: top pane for the figure, bottom pane for the buttons.
        self.split_win = wx.SplitterWindow(self)
        self.top_split = MatplotPanel(self.split_win, config)
        self.bottom_split = wx.Panel(self.split_win, style=wx.SUNKEN_BORDER)
        self.split_win.SplitHorizontally(self.top_split, self.bottom_split, 885)
        self.Maximize(True)

        # Buttons in the bottom pane. Only "Load Frames" and "Quit" are usable
        # until frames have been loaded by browseDir.
        self.Button1 = wx.Button(self.bottom_split, -1, "Load Frames", size=(200, 40), pos=(250, 25))
        self.Button1.Bind(wx.EVT_BUTTON, self.browseDir)
        self.Button1.Enable(True)

        self.Button5 = wx.Button(self.bottom_split, -1, "Help", size=(80, 40), pos=(580, 25))
        self.Button5.Bind(wx.EVT_BUTTON, self.help)
        self.Button5.Enable(False)

        self.Button2 = wx.Button(self.bottom_split, -1, "Next Frame", size=(120, 40), pos=(800, 25))
        self.Button2.Bind(wx.EVT_BUTTON, self.nextImage)
        self.Button2.Enable(False)

        self.Button4 = wx.Button(self.bottom_split, -1, "Save", size=(80, 40), pos=(1050, 25))
        self.Button4.Bind(wx.EVT_BUTTON, self.save)
        self.Button4.Enable(False)

        self.close = wx.Button(self.bottom_split, -1, "Quit", size=(80, 40), pos=(1230, 25))
        self.close.Bind(wx.EVT_BUTTON, self.quitButton)

        self.currentDirectory = os.getcwd()
        self.index = []           # image paths; filled by browseDir
        self.iter = 0             # position of the displayed image in self.index
        self.colormap = cm.hsv    # replaced by the config's colormap in browseDir
        self.updatedCoords = []   # coordinate records of labels on the current image
        self.dataFrame = None     # labels of all images; built in browseDir
        self.flag = True          # compared with len(self.bodyparts) in onClick
        # Set to 1 by nextImage/save; browseDir only creates the marker-size
        # widgets while this is still 0.
        self.file = 0
        self.config_file = config
        self.cidClick = None      # matplotlib connection id of the onClick handler

        self.addLabel = wx.CheckBox(self.top_split, label='Add new labels to existing dataset?', pos=(80, 855))
        self.addLabel.Bind(wx.EVT_CHECKBOX, self.newLabel)
        self.new_labels = False

    def newLabel(self, event):
        """Handle the "Add new labels to existing dataset?" checkbox.

        Checking it sets ``self.new_labels`` (read by ``browseDir`` and
        ``save``) and disables the checkbox so the choice cannot be undone;
        unchecking clears the flag.
        """
        self.chk = event.GetEventObject()
        if self.chk.GetValue():
            self.new_labels = True
            self.addLabel.Enable(False)
        else:
            self.new_labels = False

    def quitButton(self, event):
        """
        Quits the GUI
        """
        self.Destroy()

    def help(self, event):
        """
        Opens Instructions
        """
        wx.MessageBox('1. Select one of the body parts from the radio buttons to add a label (if necessary change config.yaml first to edit the label names). \n\n2. Right clicking on the image will add the selected label. \n The label will be marked as circle filled with a unique color. \n\n3. Hover your mouse over this newly added label to see its name. \n\n4. Use left click and drag to move the label position. \n\n5. To change the marker size mark the checkbox and move the slider. \n Change the markersize only after finalizing the position of your first label, otherwise you will not be able to move your first label around! \n\n6. Once you are happy with the position, select another body part from the radio button. \n Be careful, once you add a new body part, you will not be able to move the old labels. \n\n7. Click Next Frame to move to the next image. \n\n8. When finished labeling all the images, click \'Save\' to save all the labels as a .h5 file. \n\n9. Click OK to continue using the labeling GUI.', 'User instructions', wx.OK | wx.ICON_INFORMATION)

    def onClick(self, event):
        """Place a label for the selected body part on a right click.

        Only ``event.button == 3`` adds a label. Each radio-box entry can be
        placed once per image; ``self.buttonCounter`` records the indices
        already used. The new circle is wrapped in
        ``auxfun_drag_label.DraggablePoint`` and its coordinate record
        ``[[x, y, bodypart_name, radio_index]]`` is appended to
        ``self.updatedCoords`` for ``saveEachImage``.
        """
        x1 = event.xdata
        y1 = event.ydata
        # Only the most recently created DraggablePoint is kept (self.dr / self.drs).
        self.drs = []
        normalize = mcolors.Normalize(vmin=np.min(self.colorparams), vmax=np.max(self.colorparams))
        if event.button == 3:
            selection = self.rdb.GetSelection()
            if selection in self.buttonCounter:
                wx.MessageBox('%s is already annotated. \n Select another body part to annotate.' % (str(self.bodyparts[selection])), 'Error!', wx.OK | wx.ICON_ERROR)
            else:
                # NOTE(review): browseDir resets self.flag to 0 and nothing
                # increments it, so this notice only appears when there are no
                # body parts at all -- confirm the intended condition.
                if self.flag == len(self.bodyparts):
                    wx.MessageBox('All body parts are annotated! Click \'Save\' to save the changes. \n Click OK to continue.', 'Done!', wx.OK | wx.ICON_INFORMATION)

                color = self.colormap(normalize(selection))
                circle = [patches.Circle((x1, y1), radius=self.markerSize, fc=color, alpha=0.5)]
                self.num.append(circle)
                self.ax1f1.add_patch(circle[0])
                self.dr = auxfun_drag_label.DraggablePoint(circle[0], self.bodyparts[selection])
                self.dr.connect()
                self.buttonCounter.append(selection)
                self.dr.coords = [[x1, y1, self.bodyparts[selection], selection]]
                self.drs.append(self.dr)
                self.updatedCoords.append(self.dr.coords)

    def browseDir(self, event):
        """Ask for the frames directory, load the config and show the first frame.

        Reads scorer, body parts, marker size and colormap from the config
        file, builds a NaN-filled ``self.dataFrame`` (one row per ``*.png``
        frame, one x/y column pair per body part) and creates the body-part
        radio box. With ``self.new_labels`` set, body parts already present in
        the directory's ``CollectedData_<scorer>.h5`` are left out.
        """
        dlg = wx.DirDialog(self, "Choose the directory where your extracted frames are saved:", os.getcwd(), style=wx.DD_DEFAULT_STYLE)
        try:
            if dlg.ShowModal() != wx.ID_OK:
                # Cancelled: close the GUI and stop here -- everything below
                # needs a freshly chosen self.dir.
                self.Close(True)
                return
            self.dir = dlg.GetPath()
        finally:
            dlg.Destroy()
        self.Button1.Enable(False)
        self.Button2.Enable(True)
        self.Button5.Enable(True)

        with open(str(self.config_file), 'r') as ymlfile:
            # yaml.load() without an explicit Loader warns on PyYAML 5.x and
            # raises TypeError on PyYAML 6; the config only needs plain types.
            self.cfg = yaml.safe_load(ymlfile)
        self.scorer = self.cfg['scorer']
        self.bodyparts = self.cfg['bodyparts']
        self.videos = self.cfg['video_sets'].keys()
        self.markerSize = self.cfg['dotsize']
        self.colormap = plt.get_cmap(self.cfg['colormap'])
        self.project_path = self.cfg['project_path']

        # Body-part names become DataFrame column labels, so they must be unique.
        if len(self.bodyparts) != len(set(self.bodyparts)):
            print("Error! bodyparts must have unique labels!Please choose unique bodyparts in config.yaml file and try again. Quiting for now!")
            self.Destroy()
            return

        # Sorted so frames are shown and stored in a stable (filename) order.
        self.index = sorted(glob.glob(os.path.join(self.dir, '*.png')))
        if not self.index:
            wx.MessageBox('No .png frames were found in %s.' % self.dir, 'Error!', wx.OK | wx.ICON_ERROR)
            self.Button1.Enable(True)
            self.Button2.Enable(False)
            return
        if len(self.index) == 1:
            # There is no next frame, so saving has to be possible right away.
            self.Button2.Enable(False)
            self.Button4.Enable(True)

        self.relativeimagenames = self.index

        if self.new_labels:
            self.oldDF = pd.read_hdf(os.path.join(self.dir, 'CollectedData_' + self.scorer + '.h5'), 'df_with_missing')
            already_labeled = set(self.oldDF.columns.get_level_values(1))
            # Keep the config order instead of the arbitrary order of a set difference.
            self.bodyparts = [bp for bp in self.bodyparts if bp not in already_labeled]
        else:
            self.addLabel.Enable(False)

        self.iter = 0
        self.buttonCounter = []
        self._plot_current_image()

        self.rdb = wx.RadioBox(self.top_split, id=1, label="Select a body part to annotate", pos=(1250, 65), choices=self.bodyparts, majorDimension=1, style=wx.RA_SPECIFY_COLS, validator=wx.DefaultValidator, name=wx.RadioBoxNameStr)
        self.option = self.rdb.Bind(wx.EVT_RADIOBOX, self.onRDB)

        self.cidClick = self.canvas.mpl_connect('button_press_event', self.onClick)
        self.flag = 0
        self.num = []
        self.counter = []
        self.presentCoords = []

        # onClick normalises the radio index over this range to pick a colour.
        self.colorparams = list(range(0, len(self.bodyparts) + 1))

        # One NaN-filled (x, y) column pair per body part, one row per image.
        columns = pd.MultiIndex.from_product([[self.scorer], self.bodyparts, ['x', 'y']], names=['scorer', 'bodyparts', 'coords'])
        self.dataFrame = pd.DataFrame(np.nan, index=self.relativeimagenames, columns=columns)

        if self.file == 0:
            self.checkBox = wx.CheckBox(self.top_split, label='Adjust marker size.', pos=(500, 855))
            self.checkBox.Bind(wx.EVT_CHECKBOX, self.onChecked)
            self.slider = wx.Slider(self.top_split, -1, 5, 0, 20, size=(200, -1), pos=(500, 780), style=wx.SL_HORIZONTAL | wx.SL_AUTOTICKS | wx.SL_LABELS)
            self.slider.Bind(wx.EVT_SLIDER, self.OnSliderScroll)
            self.slider.Enable(False)

    def _plot_current_image(self):
        """Show ``self.index[self.iter]`` on a fresh figure in the top pane.

        Sets ``self.fig1``, ``self.ax1f1`` and ``self.canvas`` and adds the
        body-part colorbar and the "<iter>/<last> <filename>" title. The caller
        must already have closed/destroyed the previous figure and canvas.
        """
        from skimage import io
        self.fig1, self.ax1f1 = plt.subplots(figsize=(12, 7.8), facecolor="None")
        im = io.imread(self.index[self.iter])
        im_axis = self.ax1f1.imshow(im, self.colormap)
        cbar = self.fig1.colorbar(im_axis, ax=self.ax1f1)
        if self.bodyparts:
            im_max = np.max(im)
            # Ticks start at 12 and are spaced by max intensity / number of body
            # parts; max(1, ...) keeps range() from receiving a zero step.
            step = max(1, int(np.floor(im_max / len(self.bodyparts))))
            cbar.set_ticks(range(12, im_max, step))
            cbar.set_ticklabels(self.bodyparts)
        img_name = Path(self.index[self.iter]).name
        self.ax1f1.set_title(str(self.iter) + "/" + str(len(self.index) - 1) + " " + img_name)
        self.canvas = FigureCanvas(self.top_split, -1, self.fig1)

    def onRDB(self, event):
        """Record the newly selected radio-box index in ``self.counter``."""
        self.option = self.rdb.GetSelection()
        self.counter.append(self.option)

    def nextImage(self, event):
        """Store the current image's labels and move on to the next image.

        When the new image is the last one, "Next Frame" is disabled and
        "Save" is enabled.
        """
        self.file = 1
        self.saveEachImage()
        self.canvas.Destroy()
        plt.close(self.fig1)
        self.ax1f1.clear()
        self.iter = self.iter + 1
        # Every image starts with no body part annotated.
        self.buttonCounter = []
        self.rdb.SetSelection(0)

        if len(self.index) - self.iter == 1:
            self.Button2.Enable(False)
            self.Button4.Enable(True)

        if len(self.index) > self.iter:
            self.updatedCoords = []
            self._plot_current_image()
            self.cidClick = self.canvas.mpl_connect('button_press_event', self.onClick)

    def saveEachImage(self):
        """Copy the labels placed on the current image into ``self.dataFrame``.

        Each entry of ``self.updatedCoords`` is a list of coordinate records:
        the first record carries the body-part name at position -2 and the
        last record's first two fields are used as x and y (presumably kept
        up to date by DraggablePoint while dragging -- confirm).
        """
        plt.close(self.fig1)
        row = self.relativeimagenames[self.iter]
        for coords in self.updatedCoords:
            bodypart = coords[0][-2]
            # A single .loc call writes into the frame itself; the chained form
            # df.loc[row][col] = v assigns into a temporary Series and is lost
            # whenever pandas returns a copy (always under Copy-on-Write).
            self.dataFrame.loc[row, (self.scorer, bodypart, 'x')] = coords[-1][0]
            self.dataFrame.loc[row, (self.scorer, bodypart, 'y')] = coords[-1][1]

    def save(self, event):
        """Store the current labels and write the dataset to the frames directory.

        Writes ``CollectedData_<scorer>.csv`` and ``CollectedData_<scorer>.h5``
        (key ``df_with_missing``). When adding new labels, the previously saved
        columns (``self.oldDF``) are placed before the new ones. Afterwards the
        user may start on another directory; otherwise the GUI is destroyed.
        """
        self.saveEachImage()
        if self.new_labels:
            self.dataFrame = pd.concat([self.oldDF, self.dataFrame], axis=1)
        # os.path.join keeps the paths Windows compatible.
        self.dataFrame.to_csv(os.path.join(self.dir, "CollectedData_" + self.scorer + ".csv"))
        self.dataFrame.to_hdf(os.path.join(self.dir, "CollectedData_" + self.scorer + '.h5'), 'df_with_missing', format='table', mode='w')

        nextFilemsg = wx.MessageBox('File saved. Do you want to label another data set?', 'Repeat?', wx.YES_NO | wx.ICON_INFORMATION)
        if nextFilemsg == wx.YES:
            self.file = 1
            plt.close(self.fig1)
            self.canvas.Destroy()
            self.rdb.Destroy()
            self.buttonCounter = []
            self.updatedCoords = []
            self.dataFrame = None
            self.counter = []
            self.bodyparts = []
            self.Button1.Enable(True)
            self.slider.Enable(False)
            self.checkBox.Enable(False)
            self.browseDir(event)
        else:
            self.Destroy()
            print("You can now check the labels, using 'check_labels' before proceeding. Then,  you can use the function 'create_training_dataset' to create the training dataset.")

    def onChecked(self, event):
        """Enable the marker-size slider only while its checkbox is ticked."""
        self.cb = event.GetEventObject()
        if self.cb.GetValue():
            self.slider.Enable(True)
            self.cidClick = self.canvas.mpl_connect('button_press_event', self.onClick)
        else:
            self.slider.Enable(False)

    def OnSliderScroll(self, event):
        """Redraw the current image using the slider value as the marker radius.

        Labels already placed on this image (``self.updatedCoords``) are
        redrawn as plain circles at their last recorded position; they are not
        wrapped in DraggablePoint again.
        """
        self.drs = []
        plt.close(self.fig1)
        self.canvas.Destroy()
        self.markerSize = self.slider.GetValue()
        self._plot_current_image()
        normalize = mcolors.Normalize(vmin=np.min(self.colorparams), vmax=np.max(self.colorparams))

        for coords in self.updatedCoords:
            x1, y1 = coords[-1][0], coords[-1][1]
            # The last field of the record is the radio-box index, which picks the colour.
            color = self.colormap(normalize(coords[-1][-1]))
            circle = patches.Circle((x1, y1), radius=self.markerSize, fc=color, alpha=0.5)
            self.ax1f1.add_patch(circle)
        # Connect once, after the loop, so the new canvas always gets a click
        # handler -- even when no label has been placed on this image yet.
        self.cidClick = self.canvas.mpl_connect('button_press_event', self.onClick)
Exemplo n.º 2
0
class MainFrame(wx.Frame):
    def __init__(self):
        """
        :return: the GUI
        """
        wx.Frame.__init__(self,
                          None,
                          title='Chemostrat: Sample Classifier',
                          size=(1500, 950))

        panel = wx.Panel(self,
                         style=wx.MINIMIZE_BOX | wx.SYSTEM_MENU | wx.CAPTION
                         | wx.CLOSE_BOX | wx.CLIP_CHILDREN)
        panel.SetBackgroundColour("white")

        self.Bind(wx.EVT_CLOSE, self.on_exit)

        self.sp_home = wx.SplitterWindow(panel)
        self.content = wx.Panel(self.sp_home, style=wx.SUNKEN_BORDER)
        self.console_bar = wx.Panel(self.sp_home, style=wx.SUNKEN_BORDER)
        self.content.SetBackgroundColour("White")
        self.console_bar.SetBackgroundColour("White")
        self.sp_home.SplitHorizontally(self.content, self.console_bar, 845)
        sizer1 = wx.BoxSizer(wx.VERTICAL)
        sizer1.Add(self.sp_home, 1, wx.EXPAND)
        panel.SetSizerAndFit(sizer1)

        notebook = wx.Notebook(self.content)
        content_sizer = wx.BoxSizer(wx.VERTICAL)
        content_sizer.Add(notebook, 1, wx.ALL | wx.EXPAND, 5)
        self.content.SetSizer(content_sizer)

        style = wx.TE_MULTILINE | wx.TE_READONLY | wx.VSCROLL
        self.log = wx.TextCtrl(self.console_bar,
                               wx.ID_ANY,
                               pos=(0, 0),
                               size=(1480, 50),
                               style=style)
        print(
            "Welcome to Chemostrat: Sample Classifier. Please Load a database to begin analysis."
        )
        log_sizer = wx.BoxSizer()
        log_sizer.Add(self.log, 1, wx.EXPAND)
        self.console_bar.SetSizerAndFit(log_sizer)

        # --------------------------------------------------------------------------------------------------------------

        tab_one = TabPanel(notebook)
        tab_one.SetBackgroundColour("White")
        notebook.AddPage(tab_one, "Data Import")
        self.sp1 = wx.SplitterWindow(tab_one)
        self.selections1 = wx.Panel(self.sp1, style=wx.SUNKEN_BORDER)
        self.graphing1 = wx.Panel(self.sp1, style=wx.SUNKEN_BORDER)
        self.selections1.SetBackgroundColour("White")
        self.graphing1.SetBackgroundColour("White")
        self.sp1.SplitVertically(self.selections1, self.graphing1, 250)
        sizer1 = wx.BoxSizer(wx.VERTICAL)
        sizer1.Add(self.sp1, 1, wx.EXPAND)
        tab_one.SetSizerAndFit(sizer1)

        load_button = wx.Button(self.selections1,
                                label='Load Data',
                                pos=(10, 10),
                                size=(225, 60))
        load_button.Bind(wx.EVT_BUTTON, self.on_open)

        wx.StaticText(self.selections1, label="Filter 1:", pos=(12, 85))
        wx.StaticText(self.selections1, label="Select Choices:", pos=(12, 140))
        self.filter_combo1 = wx.ComboBox(self.selections1,
                                         -1,
                                         choices=['No Data Loaded'],
                                         pos=(12, 110),
                                         size=(220, 20))
        self.filter_check1 = wx.CheckListBox(self.selections1,
                                             -1,
                                             pos=(12, 160),
                                             size=(220, 120))
        self.filter_combo1.Bind(wx.EVT_COMBOBOX, self.on_combo1)

        wx.StaticText(self.selections1, label="Filter 2:", pos=(12, 295))
        wx.StaticText(self.selections1, label="Select Choices:", pos=(12, 350))
        self.filter_combo2 = wx.ComboBox(self.selections1,
                                         -1,
                                         choices=['No Data Loaded'],
                                         pos=(12, 320),
                                         size=(220, 20))
        self.filter_check2 = wx.CheckListBox(self.selections1,
                                             -1,
                                             pos=(12, 370),
                                             size=(220, 120))
        self.filter_combo2.Bind(wx.EVT_COMBOBOX, self.on_combo2)

        wx.StaticText(self.selections1, label="Filter 3:", pos=(12, 505))
        wx.StaticText(self.selections1, label="Select Choices:", pos=(12, 560))
        self.filter_combo3 = wx.ComboBox(self.selections1,
                                         -1,
                                         choices=['No Data Loaded'],
                                         pos=(12, 530),
                                         size=(220, 20))
        self.filter_check3 = wx.CheckListBox(self.selections1,
                                             -1,
                                             pos=(12, 580),
                                             size=(220, 120))
        self.filter_combo3.Bind(wx.EVT_COMBOBOX, self.on_combo3)

        self.combo_boxes = [
            self.filter_combo1, self.filter_combo2, self.filter_combo3
        ]
        self.check_boxes = [
            self.filter_check1, self.filter_check2, self.filter_check3
        ]
        self.data_been_filtered = False

        filter_button = wx.Button(self.selections1,
                                  label='Filter Data',
                                  pos=(10, 715),
                                  size=(225, 70))
        filter_button.Bind(wx.EVT_BUTTON, self.on_filter)

        self.grid = gridlib.Grid(self.graphing1)
        self.grid.CreateGrid(0, 0)
        grid_sizer = wx.BoxSizer(wx.VERTICAL)
        grid_sizer.Add(self.grid, 1, wx.EXPAND)
        self.graphing1.SetSizerAndFit(grid_sizer)

        # --------------------------------------------------------------------------------------------------------------

        tab_two = TabPanel(notebook)
        tab_two.SetBackgroundColour("White")
        notebook.AddPage(tab_two, "Principal Component Analysis")
        self.sp2 = wx.SplitterWindow(tab_two)
        self.selections2 = wx.Panel(self.sp2, style=wx.SUNKEN_BORDER)
        self.graphing2 = wx.Panel(self.sp2, style=wx.SUNKEN_BORDER)
        self.selections2.SetBackgroundColour("White")
        self.graphing2.SetBackgroundColour("White")
        self.sp2.SplitVertically(self.selections2, self.graphing2, 250)
        sizer2 = wx.BoxSizer(wx.VERTICAL)
        sizer2.Add(self.sp2, 1, wx.EXPAND)
        tab_two.SetSizerAndFit(sizer2)

        wx.StaticText(self.selections2,
                      label="Select Data Labels:",
                      pos=(70, 10))
        wx.StaticText(self.selections2, label="Colour:", pos=(10, 37))
        wx.StaticText(self.selections2, label="Shape", pos=(10, 62))
        wx.StaticText(self.selections2, label="x:", pos=(10, 87))
        wx.StaticText(self.selections2, label="y:", pos=(10, 110))

        self.color_on_pca = wx.ComboBox(self.selections2,
                                        choices=['No '
                                                 'Data Loaded'],
                                        pos=(80, 35),
                                        size=(152, 20))
        self.shape_on_pca = wx.ComboBox(self.selections2,
                                        choices=['No '
                                                 'Data Loaded'],
                                        pos=(80, 60),
                                        size=(152, 20))
        self.color_on_pca.Bind(
            wx.EVT_TEXT,
            lambda event: self.pop_shape_frm_plots(event, 'pca_color'))
        self.shape_on_pca.Bind(
            wx.EVT_TEXT,
            lambda event: self.pop_shape_frm_plots(event, 'pca_shape'))
        self.x_axis_selection = wx.ComboBox(self.selections2,
                                            choices=['PC1', 'PC2', 'PC3'],
                                            value='PC1',
                                            pos=(80, 85),
                                            size=(152, 20))
        self.y_axis_selection = wx.ComboBox(self.selections2,
                                            choices=['PC1', 'PC2', 'PC3'],
                                            value='PC2',
                                            pos=(80, 110),
                                            size=(152, 20))

        self.confirm_btn_PCA = wx.Button(self.selections2,
                                         label="Generate "
                                         "Graph",
                                         pos=(10, 642),
                                         size=(225, 65))
        self.confirm_btn_PCA.Bind(wx.EVT_BUTTON, self.confirm_PCA)
        self.CLR_check = wx.CheckBox(self.selections2,
                                     label="Center "
                                     "Log "
                                     "Ratio?",
                                     pos=(10, 135))
        self.CLR_check.SetValue(True)
        self.PCA_button = wx.Button(self.selections2,
                                    label='Generate PDF',
                                    pos=(10, 720),
                                    size=(225, 65))
        self.PCA_button.Bind(wx.EVT_BUTTON, self.on_PCA)
        self.PCA_button.Enable(False)

        self.PCA_selection = wx.CheckListBox(self.selections2,
                                             -1,
                                             pos=(12, 175),
                                             size=(220, 270))
        self.PCA_selection.Bind(wx.EVT_COMBOBOX, self.on_combo3)
        wx.StaticText(self.selections2, label="Size:", pos=(10, 480))
        self.size_slider_PCA = wx.Slider(self.selections2,
                                         1,
                                         25,
                                         0,
                                         100,
                                         pos=(10, 500),
                                         size=(200, -1),
                                         style=wx.SL_HORIZONTAL
                                         | wx.SL_VALUE_LABEL)
        self.arrows_check = wx.CheckBox(self.selections2,
                                        label="Plot arrows?",
                                        pos=(10, 530))
        self.arrows_check.SetValue(True)
        self.samples_check = wx.CheckBox(self.selections2,
                                         label="Plot samples?",
                                         pos=(10, 560))
        self.samples_check.SetValue(True)

        wx.StaticText(self.selections2, label="Label Points", pos=(10, 592))
        self.label_points_pca = wx.ComboBox(self.selections2,
                                            choices=['No '
                                                     'Data Loaded'],
                                            pos=(80, 590),
                                            size=(152, 20))

        # --------------------------------------------------------------------------------------------------------------

        tab_three = TabPanel(notebook)
        tab_three.SetBackgroundColour("White")
        notebook.AddPage(tab_three, "Scatter")
        self.sp3 = wx.SplitterWindow(tab_three)
        self.selections3 = wx.Panel(self.sp3, style=wx.SUNKEN_BORDER)
        self.graphing3 = wx.Panel(self.sp3, style=wx.SUNKEN_BORDER)
        self.selections3.SetBackgroundColour("White")
        self.graphing3.SetBackgroundColour("White")
        self.sp3.SplitVertically(self.selections3, self.graphing3, 250)
        sizer3 = wx.BoxSizer(wx.VERTICAL)
        sizer3.Add(self.sp3, 1, wx.EXPAND)
        tab_three.SetSizerAndFit(sizer3)

        wx.StaticText(self.selections3,
                      label="Select Data Labels:",
                      pos=(70, 10))
        wx.StaticText(self.selections3, label="Colour:", pos=(10, 37))
        wx.StaticText(self.selections3, label="Shape:", pos=(10, 62))
        wx.StaticText(self.selections3, label="x:", pos=(10, 87))
        wx.StaticText(self.selections3, label="y:", pos=(10, 112))
        wx.StaticText(self.selections3, label="z:", pos=(10, 137))
        self.color_on_scatter = wx.ComboBox(self.selections3,
                                            choices=['No Data Loaded'],
                                            pos=(80, 35),
                                            size=(152, 20))
        self.shape_on_scatter = wx.ComboBox(self.selections3,
                                            choices=['No Data Loaded'],
                                            pos=(80, 60),
                                            size=(152, 20))
        self.color_on_scatter.Bind(
            wx.EVT_TEXT,
            lambda event: self.pop_shape_frm_plots(event, 'scatter_color'))
        self.shape_on_scatter.Bind(
            wx.EVT_TEXT,
            lambda event: self.pop_shape_frm_plots(event, 'scatter_shape'))

        self.x_name_scatter = wx.ComboBox(self.selections3,
                                          choices=['No '
                                                   'Data Loaded'],
                                          pos=(80, 85),
                                          size=(65, 20))
        self.y_name_scatter = wx.ComboBox(self.selections3,
                                          choices=['No '
                                                   'Data Loaded'],
                                          pos=(80, 110),
                                          size=(65, 20))
        self.z_name_scatter = wx.ComboBox(self.selections3,
                                          choices=['No '
                                                   'Data Loaded'],
                                          pos=(80, 135),
                                          size=(65, 20))

        self.x1_name_scatter = wx.ComboBox(self.selections3,
                                           choices=['No '
                                                    'Data Loaded'],
                                           pos=(160, 85),
                                           size=(65, 20))
        self.y1_name_scatter = wx.ComboBox(self.selections3,
                                           choices=['No '
                                                    'Data Loaded'],
                                           pos=(160, 110),
                                           size=(65, 20))
        self.z1_name_scatter = wx.ComboBox(self.selections3,
                                           choices=['No '
                                                    'Data Loaded'],
                                           pos=(160, 135),
                                           size=(65, 20))

        self.x_name_scatter.Bind(wx.EVT_TEXT, self.confirm_scatter)
        self.y_name_scatter.Bind(wx.EVT_TEXT, self.confirm_scatter)
        self.z_name_scatter.Bind(wx.EVT_TEXT, self.confirm_scatter)
        self.x1_name_scatter.Bind(wx.EVT_TEXT, self.confirm_scatter)
        self.y1_name_scatter.Bind(wx.EVT_TEXT, self.confirm_scatter)
        self.z1_name_scatter.Bind(wx.EVT_TEXT, self.confirm_scatter)

        wx.StaticText(self.selections3, label="Size:", pos=(10, 160))
        self.size_slider_scatter = wx.Slider(self.selections3,
                                             1,
                                             25,
                                             0,
                                             100,
                                             pos=(10, 180),
                                             size=(200, -1),
                                             style=wx.SL_HORIZONTAL
                                             | wx.SL_VALUE_LABEL)

        wx.StaticText(self.selections3, label="x limit", pos=(10, 264))
        wx.StaticText(self.selections3, label="y limit", pos=(10, 287))
        wx.StaticText(self.selections3, label="upper", pos=(109, 240))
        wx.StaticText(self.selections3, label="lower", pos=(54, 240))

        self.xLowLim = wx.TextCtrl(self.selections3,
                                   pos=(50, 260),
                                   size=(50, 20))
        self.xUpLim = wx.TextCtrl(self.selections3,
                                  pos=(105, 260),
                                  size=(50, 20))
        self.yLowLim = wx.TextCtrl(self.selections3,
                                   pos=(50, 285),
                                   size=(50, 20))
        self.yUpLim = wx.TextCtrl(self.selections3,
                                  pos=(105, 285),
                                  size=(50, 20))

        self.scatter_log_x = wx.CheckBox(self.selections3,
                                         label="log x",
                                         pos=(10, 320))
        self.scatter_log_x.SetValue(False)

        self.scatter_log_y = wx.CheckBox(self.selections3,
                                         label="log y",
                                         pos=(10, 350))
        self.scatter_log_y.SetValue(False)

        wx.StaticText(self.selections3, label="Label Points", pos=(10, 592))
        self.label_points_scatter = wx.ComboBox(self.selections3,
                                                choices=['No '
                                                         'Data Loaded'],
                                                pos=(80, 590),
                                                size=(152, 20))

        self.confirm_btn_scatter = wx.Button(self.selections3,
                                             label="Generate "
                                             "Graph",
                                             pos=(10, 642),
                                             size=(225, 65))
        self.confirm_btn_scatter.Bind(wx.EVT_BUTTON, self.confirm_scatter)
        self.scatter_button = wx.Button(self.selections3,
                                        label='Generate PDF',
                                        pos=(10, 720),
                                        size=(225, 65))
        self.scatter_button.Bind(wx.EVT_BUTTON, self.on_scatter)
        self.scatter_button.Enable(False)

        # --------------------------------------------------------------------------------------------------------------

        tab_five = TabPanel(notebook)
        tab_five.SetBackgroundColour("White")
        notebook.AddPage(tab_five, "Ternary")
        self.sp5 = wx.SplitterWindow(tab_five)
        self.selections5 = wx.Panel(self.sp5, style=wx.SUNKEN_BORDER)
        self.graphing5 = wx.Panel(self.sp5, style=wx.SUNKEN_BORDER)
        self.selections5.SetBackgroundColour("White")
        self.graphing5.SetBackgroundColour("White")
        self.sp5.SplitVertically(self.selections5, self.graphing5, 250)
        sizer5 = wx.BoxSizer(wx.VERTICAL)
        sizer5.Add(self.sp5, 1, wx.EXPAND)
        tab_five.SetSizerAndFit(sizer5)

        wx.StaticText(self.selections5,
                      label="Select Data Labels:",
                      pos=(70, 10))
        wx.StaticText(self.selections5, label="Colour:", pos=(10, 37))
        wx.StaticText(self.selections5, label="Shape:", pos=(10, 62))
        wx.StaticText(self.selections5, label="top:", pos=(10, 87))
        wx.StaticText(self.selections5, label="left:", pos=(10, 112))
        wx.StaticText(self.selections5, label="right", pos=(10, 137))
        self.color_on_tern = wx.ComboBox(self.selections5,
                                         choices=['No Data Loaded'],
                                         pos=(80, 35),
                                         size=(152, 20))
        self.shape_on_tern = wx.ComboBox(self.selections5,
                                         choices=['No Data Loaded'],
                                         pos=(80, 60),
                                         size=(152, 20))
        self.color_on_tern.Bind(
            wx.EVT_TEXT,
            lambda event: self.pop_shape_frm_plots(event, 'tern_color'))
        self.shape_on_tern.Bind(
            wx.EVT_TEXT,
            lambda event: self.pop_shape_frm_plots(event, 'tern_shape'))

        self.x_name_tern = wx.ComboBox(self.selections5,
                                       choices=['No '
                                                'Data Loaded'],
                                       pos=(80, 85),
                                       size=(65, 20))
        self.y_name_tern = wx.ComboBox(self.selections5,
                                       choices=['No '
                                                'Data Loaded'],
                                       pos=(80, 110),
                                       size=(65, 20))
        self.z_name_tern = wx.ComboBox(self.selections5,
                                       choices=['No '
                                                'Data Loaded'],
                                       pos=(80, 135),
                                       size=(65, 20))

        self.x1_name_tern = wx.ComboBox(self.selections5,
                                        choices=['No '
                                                 'Data Loaded'],
                                        pos=(160, 85),
                                        size=(65, 20))
        self.y1_name_tern = wx.ComboBox(self.selections5,
                                        choices=['No '
                                                 'Data Loaded'],
                                        pos=(160, 110),
                                        size=(65, 20))
        self.z1_name_tern = wx.ComboBox(self.selections5,
                                        choices=['No '
                                                 'Data Loaded'],
                                        pos=(160, 135),
                                        size=(65, 20))

        wx.StaticText(self.selections5, label="size:", pos=(10, 162))

        self.size_name_tern = wx.ComboBox(self.selections5,
                                          choices=['No '
                                                   'Data Loaded'],
                                          pos=(80, 160),
                                          size=(65, 20))

        self.size_name_tern.Bind(wx.EVT_TEXT, self.confirm_scatter)
        self.x_name_tern.Bind(wx.EVT_TEXT, self.confirm_tern)
        self.y_name_tern.Bind(wx.EVT_TEXT, self.confirm_tern)
        self.z_name_tern.Bind(wx.EVT_TEXT, self.confirm_tern)
        self.x1_name_tern.Bind(wx.EVT_TEXT, self.confirm_tern)
        self.y1_name_tern.Bind(wx.EVT_TEXT, self.confirm_tern)
        self.z1_name_tern.Bind(wx.EVT_TEXT, self.confirm_tern)

        wx.StaticText(self.selections5, label="multiply left", pos=(10, 297))
        wx.StaticText(self.selections5, label="multiply right", pos=(10, 320))
        wx.StaticText(self.selections5, label="multiply top", pos=(10, 274))
        self.multiply_left = wx.TextCtrl(self.selections5,
                                         pos=(100, 295),
                                         size=(50, 20))
        self.multiply_right = wx.TextCtrl(self.selections5,
                                          pos=(100, 320),
                                          size=(50, 20))
        self.multiply_top = wx.TextCtrl(self.selections5,
                                        pos=(100, 270),
                                        size=(50, 20))

        wx.StaticText(self.selections5, label="Size:", pos=(10, 190))
        self.size_slider_tern = wx.Slider(self.selections5,
                                          1,
                                          25,
                                          0,
                                          100,
                                          pos=(10, 210),
                                          size=(200, -1),
                                          style=wx.SL_HORIZONTAL
                                          | wx.SL_VALUE_LABEL)

        self.confirm_btn_tern = wx.Button(self.selections5,
                                          label="Generate "
                                          "Graph",
                                          pos=(10, 642),
                                          size=(225, 65))
        self.confirm_btn_tern.Bind(wx.EVT_BUTTON, self.confirm_tern)
        self.tern_button = wx.Button(self.selections5,
                                     label='Generate '
                                     'PDF',
                                     pos=(10, 720),
                                     size=(225, 65))
        self.tern_button.Bind(wx.EVT_BUTTON, self.on_tern)
        self.tern_button.Enable(False)
        #------------------------------------- Tab 6 -------------------
        tab_six = TabPanel(notebook)
        tab_six.SetBackgroundColour("White")
        notebook.AddPage(tab_six, "Colors && Shapes")
        self.sp6 = wx.SplitterWindow(tab_six)
        self.selections6 = wx.Panel(self.sp6, style=wx.SUNKEN_BORDER)
        self.graphing6 = wx.Panel(self.sp6, style=wx.SUNKEN_BORDER)
        self.selections6.SetBackgroundColour("White")
        self.graphing6.SetBackgroundColour("White")
        self.sp6.SplitVertically(self.selections6, self.graphing6, 250)
        sizer6 = wx.BoxSizer(wx.VERTICAL)
        sizer6.Add(self.sp6, 1, wx.EXPAND)
        tab_six.SetSizerAndFit(sizer6)
        wx.StaticText(self.selections6, label="Select column", pos=(10, 12))

        self.colorcombo = wx.ComboBox(self.selections6,
                                      -1,
                                      pos=(110, 10),
                                      size=(110, 20),
                                      style=wx.TE_PROCESS_ENTER)
        self.colorcombo.Bind(wx.EVT_TEXT, self.pop_colorbox)

        self.list_ctrl = wx.ListCtrl(self.selections6,
                                     size=(200, 200),
                                     pos=(10, 40),
                                     style=wx.LC_REPORT | wx.BORDER_SUNKEN)
        self.list_ctrl.InsertColumn(0, 'Item')
        self.list_ctrl.InsertColumn(1, 'Shape')
        wx.StaticText(self.selections6, label="Select color", pos=(10, 253))
        self.pallet = wx.ColourPickerCtrl(self.selections6,
                                          colour='blue',
                                          pos=(140, 250),
                                          size=(100, 25))
        self.pallet.Bind(wx.EVT_COLOURPICKER_CHANGED, self.select_color)

        wx.StaticText(self.selections6, label="Select shape", pos=(10, 287))
        self.shapecombo = wx.ComboBox(self.selections6,
                                      -1,
                                      pos=(140, 285),
                                      size=(100, 25),
                                      value='Circle',
                                      choices=[
                                          'Circle', 'Triangle', 'Octagon',
                                          'Square', 'Pentagon', 'Plus', 'Star',
                                          'Diamond'
                                      ],
                                      style=wx.TE_PROCESS_ENTER)
        self.shapecombo.Bind(wx.EVT_TEXT, self.select_shape)

        #--------------------------------------------------------------------------------------------------------
        self.fig_PCA = plt.figure(figsize=(12, 8))
        self.fig_scatter = plt.figure(figsize=(12, 8))
        self.fig_tern = plt.figure(figsize=(12, 8))

        self.canvas2 = FigCanvas(self.graphing2, -1, self.fig_PCA)
        self.canvas3 = FigCanvas(self.graphing3, -1, self.fig_scatter)
        self.canvas5 = FigCanvas(self.graphing5, -1, self.fig_tern)
        # initialse dictionarys to hold colours and shapes
        self.colordict = {'All': 'b'}
        self.shapedict = {'All': ('Circle', 'o')}
        self.mplshapedict = {
            'Circle': 'o',
            'Triangle': '^',
            'Octagon': '8',
            'Square': 's',
            'Pentagon': 'p',
            'Plus': 'P',
            'Star': '*',
            'Diamond': 'D'
        }

        # --------------------------------------------------------------------------------------------------------------

    def on_combo1(self, event):
        """Fill filter_check1 with the sorted unique values of the column chosen
        in filter_combo1 and check every entry.

        Problems are reported on stdout instead of raised:
        - no data loaded yet (``self.data`` missing);
        - an unknown or empty column name;
        - values that cannot be sorted or shown as strings.
        """
        try:
            column = self.data[str(self.filter_combo1.GetValue())]
            values = sorted(column.unique().tolist())
            self.filter_check1.Clear()
            # AppendItems expects a sequence of strings, so hand over the list.
            self.filter_check1.AppendItems(values)
            for index in range(len(values)):
                self.filter_check1.Check(index, True)
        except AttributeError:
            print("INFO: No Data has been imported")
        except (KeyError, TypeError):
            # KeyError: unknown/empty column; TypeError: unsortable or non-str values.
            print("INFO: Cannot set this column as a filter")

    def on_combo2(self, event):
        """Fill filter_check2 with the sorted unique values of the column chosen
        in filter_combo2 and check every entry.

        Problems are reported on stdout instead of raised:
        - no data loaded yet (``self.data`` missing);
        - an unknown or empty column name;
        - values that cannot be sorted or shown as strings.
        """
        try:
            column = self.data[str(self.filter_combo2.GetValue())]
            values = sorted(column.unique().tolist())
            self.filter_check2.Clear()
            # AppendItems expects a sequence of strings, so hand over the list.
            self.filter_check2.AppendItems(values)
            for index in range(len(values)):
                self.filter_check2.Check(index, True)
        except AttributeError:
            print("INFO: No Data has been imported")
        except (KeyError, TypeError):
            # KeyError: unknown/empty column; TypeError: unsortable or non-str values.
            print("INFO: Cannot set this column as a filter")

    def on_combo3(self, event):
        """Fill filter_check3 with the sorted unique values of the column chosen
        in filter_combo3 and check every entry.

        Problems are reported on stdout instead of raised:
        - no data loaded yet (``self.data`` missing);
        - an unknown or empty column name;
        - values that cannot be sorted or shown as strings.
        """
        try:
            column = self.data[str(self.filter_combo3.GetValue())]
            values = sorted(column.unique().tolist())
            self.filter_check3.Clear()
            # AppendItems expects a sequence of strings, so hand over the list.
            self.filter_check3.AppendItems(values)
            for index in range(len(values)):
                self.filter_check3.Check(index, True)
        except AttributeError:
            print("INFO: No Data has been imported")
        except (KeyError, TypeError):
            # KeyError: unknown/empty column; TypeError: unsortable or non-str values.
            print("INFO: Cannot set this column as a filter")

    def on_filter(self, event):
        """Apply the filter rows to ``self.data`` and refresh the dependent widgets.

        Each (combo, check list) pair keeps only rows whose value in the combo's
        column is among the checked strings. Pairs with nothing checked are
        skipped. The result is stored in ``self.dataFiltered``.

        When at least one filter applied:
        - the grid is redrawn;
        - the numeric / non-numeric column lists are recomputed;
        - the selector combos are refreshed (filters themselves kept).
        """
        self.data_been_filtered = True
        filtered = self.data
        applied = False
        # A single pass suffices: applying the same filters again changes
        # nothing, so there is no need to repeat them (or redraw the grid) per box.
        for combo, box in zip(self.combo_boxes, self.check_boxes):
            if len(box.GetCheckedItems()) == 0:
                continue
            filtered = filtered[filtered[combo.GetValue()].isin(
                box.GetCheckedStrings())]
            applied = True

        if applied:
            filtered = filtered.reset_index(drop=True)
            self.update_grid(filtered, self.column_headers)
            # DataFrame.convert_objects() was removed from pandas;
            # infer_objects() is its documented soft-conversion replacement.
            filtered = filtered.infer_objects()
            self.column_headers = list(filtered.columns.values)
            is_number = [
                np.issubdtype(dtype, np.number) for dtype in filtered.dtypes
            ]
            self.non_numeric_cols = [
                col for col, numeric in zip(self.column_headers, is_number)
                if not numeric
            ]
            self.numeric_cols = [
                col for col, numeric in zip(self.column_headers, is_number)
                if numeric
            ]

        # Assign before refreshing the combos: their change handlers may
        # redraw plots from self.dataFiltered.
        self.dataFiltered = filtered
        if applied:
            self.update_combos(False)

    def save_graph(self, filepath, headers, object, type):
        """Export a figure as ``<filepath>/<type>_<headers[0]>.pdf``.

        filepath: target directory.
        headers: key headers of the plot; the first one is used in the name.
        object: figure-like object providing ``savefig``.
        type: plot label used in the file name and the log message.

        Failures are reported on stdout (with the cause) instead of raised.
        """
        try:
            target = os.path.join(filepath,
                                  type + "_" + str(headers[0]) + ".pdf")
            object.savefig(target, bbox_inches='tight')
        except Exception as err:
            print("Error encountered when saving plot:", err)
        else:
            print("INFO: " + type + " Export Complete")

    def pop_colorbox(self, event):
        """List the unique values of the chosen column with their shape and colour.

        Values seen for the first time get the default 'Circle' shape and a
        colour from an "hls" palette sized to the column. Values that already
        have an entry (e.g. picked by the user) keep it. Each list row's
        background shows the stored colour when it is an RGB(A) tuple.
        """
        source = self.dataFiltered if self.data_been_filtered else self.data
        items = pd.unique(source[self.colorcombo.GetValue()])

        default_shape = ('Circle', self.mplshapedict.get('Circle'))
        palette = None
        for index, item in enumerate(items):
            # Only fill in missing entries so earlier choices are not clobbered.
            if item not in self.shapedict:
                self.shapedict[item] = default_shape
            if item not in self.colordict:
                if palette is None:
                    palette = sns.color_palette("hls", len(items))
                self.colordict[item] = palette[index]

        self.list_ctrl.DeleteAllItems()
        for row, item in enumerate(items):
            self.list_ctrl.InsertItem(row, item)
            self.list_ctrl.SetItem(row, 1, self.shapedict[item][0])
            colour = self.colordict.get(item)
            try:
                # Stored colours use 0-1 float channels; wx wants 0-255 ints.
                self.list_ctrl.SetItemBackgroundColour(
                    row, tuple(int(round(255 * x)) for x in colour))
            except TypeError:
                # Non-tuple colours (e.g. the 'b' default) are left uncoloured.
                pass

    def pop_shape_frm_plots(self, event, plottype):
        """Copy a plot tab's colour/shape column choice into ``colorcombo``.

        plottype: one of 'pca_shape', 'scatter_shape', 'tern_shape',
            'pca_color', 'scatter_color', 'tern_color'. Unknown values are
            ignored, as are empty selections.
        """
        # plottype -> attribute name of the combo holding the chosen column.
        source_combos = {
            'pca_shape': 'shape_on_pca',
            'scatter_shape': 'shape_on_scatter',
            'tern_shape': 'shape_on_tern',
            'pca_color': 'color_on_pca',
            'scatter_color': 'color_on_scatter',
            'tern_color': 'color_on_tern',
        }
        name = source_combos.get(plottype)
        if name is None:
            return
        value = getattr(self, name).GetValue()
        if len(value) > 0:
            self.colorcombo.SetValue(value)

    def select_color(self, event):
        """Apply the picked colour to the selected list row and remember it.

        The colour is stored in ``colordict`` under the row's item text, with
        every channel scaled to 0-1. Does nothing when no row is selected.
        """
        row = self.list_ctrl.GetFirstSelected()
        if row == -1:
            # No selection: there is no item to recolour.
            return
        color = self.pallet.GetColour()
        self.list_ctrl.SetItemBackgroundColour(row, color)
        item = self.list_ctrl.GetItemText(row, 0)
        # 255.0 keeps the division fractional on Python 2 as well.
        self.colordict[item] = tuple(x / 255.0 for x in color)

    def select_shape(self, event):
        """Assign the shape named in ``shapecombo`` to the selected list row.

        The list row shows the shape name, and ``shapedict`` maps the row's item
        to ``(shape name, matplotlib marker code)``. Names not in
        ``mplshapedict`` are ignored. This handler also runs for partially typed
        text; ignoring such names avoids a crash on ``zip(..., None)``.
        It does nothing when no row is selected.
        """
        row = self.list_ctrl.GetFirstSelected()
        shape = self.shapecombo.GetValue()
        code = self.mplshapedict.get(shape)
        if row == -1 or code is None:
            return
        item = self.list_ctrl.GetItemText(row, 0)
        self.list_ctrl.SetItem(row, 1, shape)
        self.shapedict[item] = (shape, code)

    def confirm_PCA(self, event):
        """Rebuild the PCA figure from the PCA tab's current selections.

        The filtered frame is used once a filter has been applied, otherwise
        the full data. The figure returned by ``pca.pca_`` is kept for export.
        The canvas is redrawn and the PDF button enabled.
        """
        self.fig_PCA.clear()
        self.checkedPCAStrings = self.PCA_selection.GetCheckedStrings()
        self.pca_color = self.color_on_pca.GetValue()
        self.pca_shape = self.shape_on_pca.GetValue()

        headers = [self.x_axis_selection.GetValue()]
        headers.append(self.y_axis_selection.GetValue())
        headers.append(self.checkedPCAStrings)
        headers.append(self.pca_color)
        headers.append(self.size_slider_PCA.GetValue())
        headers.append(self.pca_shape)
        headers.append(self.label_points_pca.GetValue())
        self.key_headers_PCA = headers

        if self.data_been_filtered:
            source = self.dataFiltered
        else:
            source = self.data
        figure = pca.pca_(source, self.key_headers_PCA, self.fig_PCA,
                          self.CLR_check.GetValue(),
                          self.arrows_check.GetValue(),
                          self.samples_check.GetValue(),
                          self.colordict, self.shapedict)
        self.fig_PCA = figure
        self.PCA_plot = figure

        self.canvas2.draw()
        self.confirm_btn_PCA.SetLabelText("Update Graph")
        self.PCA_button.Enable(True)

    def confirm_scatter(self, event):
        """Rebuild the scatter figure from the scatter tab's selections.

        The plot is only regenerated when both the x and y columns are chosen;
        otherwise the cleared figure is redrawn as is. The filtered frame is
        used once a filter has been applied.
        """
        self.fig_scatter.clear()
        self.scatter_color = self.color_on_scatter.GetValue()
        self.scatter_shape = self.shape_on_scatter.GetValue()

        point_size = self.size_slider_scatter.GetValue()
        axis_combos = (self.x_name_scatter, self.y_name_scatter,
                       self.z_name_scatter, self.x1_name_scatter,
                       self.y1_name_scatter, self.z1_name_scatter)
        headers = [self.scatter_color]
        headers.extend(combo.GetValue() for combo in axis_combos)
        headers.extend([point_size, self.scatter_shape,
                        self.label_points_scatter.GetValue()])
        self.key_headers_scatter = headers

        limits = [combo.GetValue() for combo in
                  (self.xLowLim, self.xUpLim, self.yLowLim, self.yUpLim)]
        log_scales = [self.scatter_log_x.GetValue(),
                      self.scatter_log_y.GetValue()]

        use_filtered = self.data_been_filtered
        have_axes = (len(self.x_name_scatter.GetValue()) != 0
                     and len(self.y_name_scatter.GetValue()) != 0)
        if have_axes:
            source = self.dataFiltered if use_filtered else self.data
            self.fig_scatter = pca.blank_scatter_plot(
                source, self.key_headers_scatter, limits, self.fig_scatter,
                log_scales, self.colordict, self.shapedict)

        self.scatter_plot = self.fig_scatter
        self.canvas3.draw()
        self.confirm_btn_scatter.SetLabelText("Update Graph")
        self.scatter_button.Enable(True)

    def confirm_tern(self, event):
        """Rebuild the ternary figure from the ternary tab's selections.

        The plot is only regenerated when the top, left and right columns
        (x, y, z) are all chosen. The canvas is then recreated around
        ``self.tern``; presumably this is because ``pca.ternary_`` may return a
        different figure object (confirm).
        """
        self.fig_tern.clear()
        self.tern_color = self.color_on_tern.GetValue()
        self.tern_shape = self.shape_on_tern.GetValue()

        x, y, z, x1, y1, z1 = [
            combo.GetValue() for combo in
            (self.x_name_tern, self.y_name_tern, self.z_name_tern,
             self.x1_name_tern, self.y1_name_tern, self.z1_name_tern)
        ]
        self.key_headers_tern = [self.tern_color, x, y, z, x1, y1, z1,
                                 self.size_name_tern.GetValue(),
                                 self.tern_shape]
        self.scalers = [combo.GetValue() for combo in
                        (self.multiply_left, self.multiply_right,
                         self.multiply_top)]

        use_filtered = self.data_been_filtered
        if x and y and z:
            source = self.dataFiltered if use_filtered else self.data
            self.fig_tern = pca.ternary_(source, self.key_headers_tern,
                                         self.fig_tern, self.scalers,
                                         self.size_slider_tern.GetValue(),
                                         self.colordict, self.shapedict)

        self.tern = self.fig_tern
        self.canvas5.Destroy()
        self.canvas5 = FigCanvas(self.graphing5, -1, self.tern)
        self.confirm_btn_tern.SetLabelText("Update Graph")
        self.tern_button.Enable(True)

    def update_grid(self, data, col_headers):
        """Resize ``self.grid`` to *data* and fill it with each cell's ``str()``.

        data: a DataFrame whose first ``len(col_headers)`` columns are shown.
        col_headers: column labels for the grid.

        Failures are reported on stdout (with the cause) instead of raised.
        """
        try:
            grid = self.grid
            grid.ClearGrid()
            new_rows, new_cols = len(data), len(col_headers)
            current_rows = grid.GetNumberRows()
            current_cols = grid.GetNumberCols()
            if current_rows > new_rows:
                grid.DeleteRows(0, current_rows - new_rows, True)
            elif current_rows < new_rows:
                grid.AppendRows(new_rows - current_rows)
            if current_cols > new_cols:
                grid.DeleteCols(0, current_cols - new_cols, True)
            elif current_cols < new_cols:
                grid.AppendCols(new_cols - current_cols)
            for col, label in enumerate(col_headers):
                grid.SetColLabelValue(col, label)
            # Iterate rows once instead of a positional .iloc lookup per cell.
            for row, values in enumerate(data.itertuples(index=False, name=None)):
                for col in range(new_cols):
                    grid.SetCellValue(row, col, str(values[col]))
        except Exception as err:
            print("INFO: Error encountered when updating grid.", err)

    def on_open(self, event):
        """Ask for a database file, load it and refresh the grid and selectors.

        The file is read with ``dc.standard_cleanse``. Its columns are split
        into numeric and non-numeric lists that feed the selector combos.
        Cancelling the dialog leaves the currently loaded data untouched.
        """
        dlg = wx.FileDialog(self,
                            message="Load Database",
                            defaultDir=os.getcwd(),
                            defaultFile="")
        try:
            if dlg.ShowModal() != wx.ID_OK:
                # Stop on cancel so the current data is neither reloaded nor
                # replaced (and no unset filepath is read).
                return
            self.filepath = dlg.GetPath()
        finally:
            # Release the native dialog once its result has been read.
            dlg.Destroy()

        # os.sep instead of a hard-coded backslash so exports work off Windows.
        self.currentdir = os.path.dirname(self.filepath) + os.sep
        print("INFO: Database Selected: " + self.filepath)
        self.data_source = str(self.filepath)

        self.data = dc.standard_cleanse(self.filepath, None)
        self.data = self.data.reset_index(drop=True)
        # A new dataset invalidates any earlier filter result, so reset both
        # the filtered frame and the flag that selects it.
        self.dataFiltered = []
        self.data_been_filtered = False
        self.column_headers = list(self.data.columns.values)
        is_number = [
            np.issubdtype(dtype, np.number) for dtype in self.data.dtypes
        ]
        self.non_numeric_cols = [
            col for col, numeric in zip(self.column_headers, is_number)
            if not numeric
        ]
        self.numeric_cols = [
            col for col, numeric in zip(self.column_headers, is_number)
            if numeric
        ]

        self.update_grid(self.data, self.column_headers)
        self.update_combos(True)
        print("INFO: Data Imported Successfully")

    def update_combos(self, updatefilters):
        """Repopulate the column selectors after data is loaded or filtered.

        Each selector combo is cleared and refilled from ``numeric_cols``,
        ``non_numeric_cols`` or ``column_headers``, and its previous value is
        restored. The PCA check list keeps its checked columns, or checks every
        numeric column when none were checked.

        updatefilters: when true, the three filter combos and their check lists
            are cleared and rebuilt. Otherwise they keep their entries and only
            gain columns they do not list yet. ``colorcombo`` is never cleared;
            it likewise only gains missing columns.
        """
        def append_missing(combo, choices):
            # Append only entries the combo lacks; a plain append on a combo
            # that is not cleared would duplicate every entry on each call.
            present = set(combo.GetItems())
            combo.AppendItems([c for c in choices if c not in present])

        non_numeric = self.non_numeric_cols
        numeric = self.numeric_cols
        headers = self.column_headers
        # (combo, choices) in the order they are cleared, refilled and restored.
        selectors = [
            (self.color_on_pca, non_numeric),
            (self.shape_on_pca, non_numeric),
            (self.color_on_scatter, non_numeric),
            (self.label_points_pca, headers),
            (self.label_points_scatter, headers),
            (self.shape_on_scatter, non_numeric),
            (self.color_on_tern, non_numeric),
            (self.shape_on_tern, non_numeric),
            (self.x_name_scatter, numeric),
            (self.y_name_scatter, numeric),
            (self.z_name_scatter, numeric),
            (self.x1_name_scatter, numeric),
            (self.y1_name_scatter, numeric),
            (self.z1_name_scatter, numeric),
            (self.x_name_tern, numeric),
            (self.y_name_tern, numeric),
            (self.z_name_tern, numeric),
            (self.x1_name_tern, numeric),
            (self.y1_name_tern, numeric),
            (self.z1_name_tern, numeric),
            (self.size_name_tern, numeric),
        ]
        filter_combos = [self.filter_combo1, self.filter_combo2,
                         self.filter_combo3]
        filter_checks = [self.filter_check1, self.filter_check2,
                         self.filter_check3]

        # Remember the current selections before anything is cleared.
        selected_PCA_items = self.PCA_selection.GetCheckedStrings()
        saved_values = [combo.GetValue() for combo, _ in selectors]
        saved_filters = [combo.GetValue() for combo in filter_combos]

        self.PCA_selection.Clear()
        self.PCA_selection.SetItems(numeric)
        items = self.PCA_selection.GetItems()

        if updatefilters:
            for widget in filter_combos + filter_checks:
                widget.Clear()

        for combo, _ in selectors:
            combo.Clear()
        for combo, choices in selectors:
            combo.AppendItems(choices)
        for combo in filter_combos:
            if updatefilters:
                combo.AppendItems(non_numeric)
            else:
                append_missing(combo, non_numeric)
        append_missing(self.colorcombo, non_numeric)

        if len(selected_PCA_items) > 0:
            self.PCA_selection.SetCheckedStrings(selected_PCA_items)
        else:
            self.PCA_selection.SetCheckedStrings(items)

        # Restore previous values (size_name_tern included, so it survives).
        for (combo, _), value in zip(selectors, saved_values):
            combo.SetValue(value)
        for combo, value in zip(filter_combos, saved_filters):
            combo.SetValue(value)

    def on_PCA(self, event):
        """Export the current PCA figure as a PDF into ``self.currentdir``."""
        # The log message names the plot actually exported (was a stale
        # "SandClass" label).
        print("INFO: Exporting PCA")
        self.save_graph(self.currentdir, self.key_headers_PCA, self.PCA_plot,
                        "PCA")

    def on_scatter(self, event):
        """Export the current scatter figure as a PDF into ``self.currentdir``."""
        # The log message names the plot actually exported (was a stale
        # "Pettijohn" label).
        print("INFO: Exporting Scatter")
        self.save_graph(self.currentdir, self.key_headers_scatter,
                        self.scatter_plot, "Scatter")

    def on_tern(self, event):
        """Write the ternary figure to a PDF in ``self.currentdir``."""
        print("INFO: Exporting Ternary")
        directory = self.currentdir
        headers = self.key_headers_tern
        figure = self.tern
        self.save_graph(directory, headers, figure, "Ternary")

    def on_exit(self, event):
        """Log the shutdown and destroy this frame."""
        message = 'INFO: Closing'
        print(message)
        self.Destroy()
Exemplo n.º 3
0
class MainFrame(wx.Frame):
    """Refinement GUI for machine-predicted (or original) body-part labels.

    One frame at a time is drawn on a matplotlib canvas in the top splitter
    panel; points can be dragged/removed via ``auxfun_drag.DraggablePoint``
    and the edited table is written to ``CollectedData_<scorer>.h5/.csv`` in
    the directory of the loaded labels file.
    """

    def __init__(self, parent, config, Screens, scale_w, scale_h, winHack, img_scale):
        """Create the window and read project settings from *config*.

        Args:
            parent: unused; the frame is always created as a top-level window.
            config: path to the project's YAML config file.
            Screens: divisor applied to the display width (presumably the
                number of screens the display spans).
            scale_w, scale_h: fractions of the display width/height to use.
            winHack: extra multiplier applied to the frame size.
            img_scale: factor turning the GUI size into the figure size.
        """
        displaysize = wx.GetDisplaySize()

        w = displaysize[0]
        h = displaysize[1]
        self.gui_width = (w * scale_w) / Screens
        self.gui_height = h * scale_h

        if self.gui_width < 600 or self.gui_height < 500:
            print("Your screen width", w, "and height", h)
            print("Scaled GUI width", self.gui_width, "and height", self.gui_height)
            print("Please adjust scale_h and scale_w, or get a bigger screen!")

        self.size = displaysize

        # wx rejects float geometry on recent wxPython/Python versions, so every
        # size/position handed to wx is converted to int.
        wx.Frame.__init__(self, None, title="DeepLabCut2.0 - Refinement GUI",
                          size=(int(self.gui_width * winHack), int(self.gui_height * winHack)),
                          style=wx.DEFAULT_FRAME_STYLE)

        # Add SplitterWindow panels: top for the figure, bottom for the buttons.
        self.split_win = wx.SplitterWindow(self)
        # MatplotPanel links the plotting panel with this frame.
        self.top_split = MatplotPanel(self.split_win, config)
        self.bottom_split = wx.Panel(self.split_win, style=wx.SUNKEN_BORDER)
        self.split_win.SplitHorizontally(self.top_split, self.bottom_split, int(self.gui_height * .9))

        self.top_split.SetBackgroundColour((255, 255, 255))
        self.bottom_split.SetBackgroundColour((74, 34, 101))

        self.statusbar = self.CreateStatusBar()
        self.statusbar.SetStatusText("")

        # Buttons in the bottom panel, bound to their handlers.
        self.Button1 = wx.Button(self.bottom_split, -1, "Load Labels", size=(150, 40), pos=self._scaled_pos(.1, .01))
        self.Button1.Bind(wx.EVT_BUTTON, self.browseDir)

        self.Button5 = wx.Button(self.bottom_split, -1, "Help", size=(60, 40), pos=self._scaled_pos(.2, .01))
        self.Button5.Bind(wx.EVT_BUTTON, self.help)
        self.Button5.Enable(True)

        self.Button3 = wx.Button(self.bottom_split, -1, "Previous Image", size=(150, 40), pos=self._scaled_pos(.35, .01))
        self.Button3.Bind(wx.EVT_BUTTON, self.prevImage)
        self.Button3.Enable(False)

        self.Button2 = wx.Button(self.bottom_split, -1, "Next Image", size=(130, 40), pos=self._scaled_pos(.45, .01))
        self.Button2.Bind(wx.EVT_BUTTON, self.nextImage)
        self.Button2.Enable(False)

        self.Button4 = wx.Button(self.bottom_split, -1, "Save", size=(100, 40), pos=self._scaled_pos(.6, .01))
        self.Button4.Bind(wx.EVT_BUTTON, self.save)
        self.Button4.Enable(False)

        self.close = wx.Button(self.bottom_split, -1, "Quit", size=(100, 40), pos=self._scaled_pos(.69, .01))
        self.close.Bind(wx.EVT_BUTTON, self.quitButton)
        self.close.Enable(True)

        # Controls overlaid on the figure panel.
        self.adjustLabelCheck = wx.CheckBox(self.top_split, label='Adjust original labels?', pos=self._scaled_pos(.1, .85))
        self.adjustLabelCheck.Bind(wx.EVT_CHECKBOX, self.adjustLabel)

        # NOTE: this rebinds self.Button5 (the Help button above) to the Zoom
        # button; Help keeps working through its event binding.
        self.Button5 = wx.Button(self.top_split, -1, "Zoom", size=(60, 30), pos=self._scaled_pos(.6, .85))
        self.Button5.Bind(wx.EVT_BUTTON, self.zoom)

        self.Button6 = wx.Button(self.top_split, -1, "Pan", size=(60, 30), pos=self._scaled_pos(.65, .85))
        self.Button6.Bind(wx.EVT_BUTTON, self.pan)

        self.Button7 = wx.Button(self.top_split, -1, "Home", size=(60, 30), pos=self._scaled_pos(.7, .85))
        self.Button7.Bind(wx.EVT_BUTTON, self.home)

        self.Bind(wx.EVT_CLOSE, self.closewindow)

        self.currentDirectory = os.getcwd()
        self.index = []
        self.iter = []
        # None until the user picks a threshold; plot() treats None as
        # "no threshold" (comparing a likelihood with [] relied on
        # deprecated empty-array truthiness).
        self.threshold = None
        self.file = 0
        # safe_load: yaml.load without an explicit Loader fails on PyYAML >= 6.
        with open(str(config), 'r') as ymlfile:
            cfg = yaml.safe_load(ymlfile)
        self.humanscorer = cfg['scorer']
        self.move2corner = cfg['move2corner']
        self.center = cfg['corner2move2']
        self.colormap = plt.get_cmap(cfg['colormap'])
        self.markerSize = cfg['dotsize']
        self.adjust_original_labels = False
        self.alpha = cfg['alphavalue']
        self.iterationindex = cfg['iteration']
        self.project_path = cfg['project_path']

        imgW = self.gui_width * img_scale  # was 12 inches (perhaps add dpi!)
        imgH = self.gui_height * img_scale  # was 7 inches

        self.img_size = (imgW, imgH)  # width, height in inches.

# ###########################################################################
# Small private helpers
# ###########################################################################
    def _scaled_pos(self, fx, fy):
        """Return an integer (x, y) at fractions fx, fy of the GUI width/height."""
        return (int(self.gui_width * fx), int(self.gui_height * fy))

    def _open_current_image(self):
        """Open the frame self.index[self.iter] (paths are relative to project_path)."""
        return PIL.Image.open(os.path.join(self.project_path, self.index[self.iter]))

    def _set_axes_title(self, show_threshold):
        """Title the axes with "<iter>/<last> <stem>", optionally adding the threshold."""
        title = (str(self.iter) + "/" + str(len(self.index) - 1) + " "
                 + str(Path(self.index[self.iter]).stem))
        if show_threshold:
            title += " " + " Threshold chosen is: " + "{0:.2f}".format(self.threshold)
        self.ax1f1.set_title(title)

    def _ask_threshold(self, default=0.1):
        """Prompt for the likelihood threshold; return *default* on unparsable input."""
        textBox = wx.TextEntryDialog(self, "Select the likelihood threshold", caption="Enter the threshold", value="0.1")
        textBox.ShowModal()
        value = textBox.GetValue()
        textBox.Destroy()
        try:
            return float(value)
        except ValueError:
            print("Invalid threshold %r; using %s instead." % (value, default))
            return default

# ###########################################################################
# functions for button responses
# ###########################################################################
    def closewindow(self, event):
        """EVT_CLOSE handler: destroy the frame."""
        self.Destroy()

    def adjustLabel(self, event):
        """Toggle editing of the original CollectedData labels instead of machine labels."""
        self.chk = event.GetEventObject()
        self.adjust_original_labels = bool(self.chk.GetValue())

    def zoom(self, event):
        """Toggle the matplotlib toolbar's zoom mode."""
        self.statusbar.SetStatusText("Zoom")
        self.toolbar.zoom()
        self.Refresh(eraseBackground=True)

    def home(self, event):
        """Reset the view through the matplotlib toolbar."""
        self.statusbar.SetStatusText("Home")
        self.toolbar.home()
        self.Refresh(eraseBackground=True)

    def pan(self, event):
        """Toggle the matplotlib toolbar's pan mode."""
        self.statusbar.SetStatusText("Pan")
        self.toolbar.pan()
        self.Refresh(eraseBackground=True)

    def OnSliderScroll(self, event):
        """
        Adjust marker size for plotting the annotations (redraws the current frame).
        """
        self.drs = []
        self.updatedCoords = []
        plt.close(self.fig1)
        # Drop the old canvas; otherwise each slider move stacks another one.
        self.canvas.Destroy()
        self.fig1, self.ax1f1 = plt.subplots(figsize=self.img_size, facecolor="None")
        self.markerSize = self.slider.GetValue()
        im = self._open_current_image()
        im_axis = self.ax1f1.imshow(im, self.colormap)
        self._set_axes_title(show_threshold=not self.adjust_original_labels)
        self.canvas = FigureCanvas(self.top_split, -1, self.fig1)
        MainFrame.plot(self, im, im_axis)
        self.toolbar = NavigationToolbar(self.canvas)
        MainFrame.confirm(self)

    def browseDir(self, event):
        """
        Ask for the labels file to refine and display its first frame.

        The machine-labels file of the current iteration is requested, or a
        CollectedData_*.h5 file when "Adjust original labels?" is checked.
        Cancelling the dialog closes the GUI.
        """
        self.adjustLabelCheck.Enable(False)

        if self.adjust_original_labels:
            dlg = wx.FileDialog(self, "Choose the labeled dataset file(CollectedData_*.h5 file)", "", "", "All CollectedData_*.h5 files(CollectedData_*.h5)|CollectedData_*.h5", wx.FD_OPEN | wx.FD_FILE_MUST_EXIST)
        else:
            fname = str('machinelabels-iter' + str(self.iterationindex) + '.h5')
            dlg = wx.FileDialog(self, "Choose the machinelabels file for current iteration.",
                                "", "", wildcard=fname,
                                style=wx.FD_OPEN | wx.FD_FILE_MUST_EXIST)

        if dlg.ShowModal() != wx.ID_OK:
            # Cancelled: close and stop here. Previously execution fell through
            # (destroying the dialog twice) and used a missing/stale file path.
            dlg.Destroy()
            self.Close(True)
            return

        self.data_file = dlg.GetPath()
        dlg.Destroy()
        self.dir = str(Path(self.data_file).parents[0])
        print(self.dir)
        self.fileName = str(Path(self.data_file).stem)
        self.Button1.Enable(False)
        self.Button2.Enable(True)
        self.Button4.Enable(True)
        self.Button5.Enable(True)
        self.close.Enable(True)

        self.fig1, self.ax1f1 = plt.subplots(figsize=self.img_size, facecolor="None")
        self.dataname = str(self.data_file)
        self.iter = 0

        if os.path.isfile(self.dataname):
            self.Dataframe = pd.read_hdf(self.dataname, 'df_with_missing')
            self.scorer = self.Dataframe.columns.get_level_values(0)[0]
            bodyParts = self.Dataframe.columns.get_level_values(1)
            # Unique body parts in order of first appearance.
            _, idx = np.unique(bodyParts, return_index=True)
            self.bodyparts2plot = bodyParts[np.sort(idx)]
            self.num_joints = len(self.bodyparts2plot)
            self.index = list(self.Dataframe.iloc[:, 0].index)
            self.drs = []
            self.updatedCoords = []

            im = self._open_current_image()
            im_axis = self.ax1f1.imshow(im, self.colormap)

            if self.file == 0:
                # Marker-size controls are created only on the first load.
                self.checkBox = wx.CheckBox(self.top_split, label='Adjust marker size.', pos=self._scaled_pos(.43, .85))
                self.checkBox.Bind(wx.EVT_CHECKBOX, self.onChecked)
                self.slider = wx.Slider(self.top_split, -1, self.markerSize, 0, 20, size=(200, -1), pos=self._scaled_pos(.43, .8), style=wx.SL_HORIZONTAL | wx.SL_AUTOTICKS | wx.SL_LABELS)
                self.slider.Bind(wx.EVT_SLIDER, self.OnSliderScroll)
                self.slider.Enable(False)

            self.canvas = FigureCanvas(self.top_split, -1, self.fig1)
            self.colorparams = list(range(0, (self.num_joints + 1)))
            MainFrame.plot(self, im, im_axis)
            self.toolbar = NavigationToolbar(self.canvas)

            if not self.adjust_original_labels:

                instruction = wx.MessageBox('1. Enter the likelihood threshold. \n\n2. Each prediction will be shown with a unique color. \n All the data points above the threshold will be marked as circle filled with a unique color. All the data points below the threshold will be marked with a hollow circle. \n\n3. Enable the checkbox to adjust the marker size. \n\n4.  Hover your mouse over data points to see the labels and their likelihood. \n\n5. Left click and drag to move the data points. \n\n6. Right click on any data point to remove it. Be careful, you cannot undo this step. \n Click once on the zoom button to zoom-in the image.The cursor will become cross, click and drag over a point to zoom in. \n Click on the zoom button again to disable the zooming function and recover the cursor. \n Use pan button to pan across the image while zoomed in. Use home button to go back to the full;default view. \n\n7. When finished click \'Save\' to save all the changes. \n\n8. Click OK to continue', 'User instructions', wx.OK | wx.ICON_INFORMATION)

                if instruction == wx.OK:
                    # Redraw the frame using the chosen likelihood threshold.
                    self.threshold = self._ask_threshold()
                    self.drs = []
                    self.updatedCoords = []
                    plt.close(self.fig1)
                    self.canvas.Destroy()
                    # Same figure size as every other redraw (was hard-coded 12 x 7.8).
                    self.fig1, self.ax1f1 = plt.subplots(figsize=self.img_size, facecolor="None")
                    im = self._open_current_image()
                    im_axis = self.ax1f1.imshow(im, self.colormap)
                    self._set_axes_title(show_threshold=True)
                    self.canvas = FigureCanvas(self.top_split, -1, self.fig1)
                    MainFrame.plot(self, im, im_axis)
                    MainFrame.confirm(self)
                    self.toolbar = NavigationToolbar(self.canvas)
                else:
                    self.threshold = 0.1

                self._set_axes_title(show_threshold=True)
            else:
                self._set_axes_title(show_threshold=False)

        else:
            msg = wx.MessageBox('No Machinelabels file found! Want to retry?', 'Error!', wx.YES_NO | wx.ICON_WARNING)
            if msg == wx.YES:
                self.Button1.Enable(True)
                self.Button2.Enable(False)
                self.Button4.Enable(False)
            else:
                self.Destroy()

    def nextImage(self, event):
        """
        Reads the next image and enables the user to move the annotations.

        An all-zero frame triggers a dialog offering to drop it from the table.
        The view then returns to the preceding frame.
        """
        MainFrame.confirm(self)
        self.canvas.Destroy()
        plt.close(self.fig1)
        self.Button3.Enable(True)
        self.checkBox.Enable(False)
        self.slider.Enable(False)
        self.iter = self.iter + 1
        self.fig1, self.ax1f1 = plt.subplots(figsize=self.img_size, facecolor="None")

        # Checks for the last image and disables the Next button
        if len(self.index) - self.iter == 1:
            self.Button2.Enable(False)

        if len(self.index) > self.iter:
            self.updatedCoords = []
            self.drs = []
            im = self._open_current_image()
            # Draw once with the configured colormap (a second plain imshow used
            # to be layered on top, hiding the colormap for grayscale frames).
            im_axis = self.ax1f1.imshow(im, self.colormap)
            self._set_axes_title(show_threshold=not self.adjust_original_labels)
            self.canvas = FigureCanvas(self.top_split, -1, self.fig1)

            if np.max(im) == 0:
                msg = wx.MessageBox('Invalid image. Click Yes to remove', 'Error!', wx.YES_NO | wx.ICON_WARNING)
                if msg == wx.YES:
                    self.Dataframe = self.Dataframe.drop(self.index[self.iter])
                    self.index = list(self.Dataframe.iloc[:, 0].index)
                self.iter = self.iter - 1
                if self.iter == 0:
                    self.Button3.Enable(False)
                # Replace the figure/canvas showing the invalid frame instead of
                # stacking a second canvas on a closed figure.
                self.canvas.Destroy()
                plt.close(self.fig1)
                self.fig1, self.ax1f1 = plt.subplots(figsize=self.img_size, facecolor="None")
                im = self._open_current_image()
                im_axis = self.ax1f1.imshow(im, self.colormap)
                self._set_axes_title(show_threshold=not self.adjust_original_labels)
                self.canvas = FigureCanvas(self.top_split, -1, self.fig1)
                print(self.iter)
            MainFrame.plot(self, im, im_axis)
            self.toolbar = NavigationToolbar(self.canvas)
        else:
            self.Button2.Enable(False)

    def prevImage(self, event):
        """
        Checks the previous Image and enables user to move the annotations.
        """
        MainFrame.confirm(self)
        self.canvas.Destroy()
        self.Button2.Enable(True)
        self.checkBox.Enable(False)
        self.slider.Enable(False)
        plt.close(self.fig1)
        self.fig1, self.ax1f1 = plt.subplots(figsize=self.img_size, facecolor="None")
        self.iter = self.iter - 1

        # Checks for the first image and disables the Previous button
        if self.iter == 0:
            self.Button3.Enable(False)

        if self.iter >= 0:
            self.updatedCoords = []
            self.drs = []
            im = self._open_current_image()
            im_axis = self.ax1f1.imshow(im, self.colormap)
            self._set_axes_title(show_threshold=not self.adjust_original_labels)
            self.canvas = FigureCanvas(self.top_split, -1, self.fig1)
            MainFrame.plot(self, im, im_axis)
            self.toolbar = NavigationToolbar(self.canvas)
        else:
            self.Button3.Enable(False)

    def quitButton(self, event):
        """
        Quits the GUI without saving.
        """
        plt.close('all')
        print("Closing... you did not hit save!")
        self.Destroy()

    def help(self, event):
        """
        Opens the instructions matching the current mode.
        """
        if self.adjust_original_labels:
            wx.MessageBox('1. Each label will be shown with a unique color. \n\n2. Enable the checkbox to adjust the marker size (you will not be able to zoom/pan/home until the next frame). \n\n3.  Hover your mouse over data points to see the labels. \n\n4. LEFT click+ drag to move the data points. \n\n5. RIGHT click on any data point to remove it. Be careful, you cannot undo this step! \n\n 6. Click once on the zoom button to zoom-in the image. The cursor will become cross, click and drag over a point to zoom in. \n Click on the zoom button again to disable the zooming function and recover the cursor. \n\n 7. Use pan button to pan across the image while zoomed in. Use home button to go back to the full default view. \n\n8. When finished click \'Save\' to save all the changes. \n\n9. Click OK to continue', 'Instructions to use!', wx.OK | wx.ICON_INFORMATION)
        else:
            wx.MessageBox('1. Enter the likelihood threshold. \n\n2. All the data points above the threshold will be marked as circle filled with a unique color. All the data points below the threshold will be marked with a hollow circle. \n\n3. Enable the checkbox to adjust the marker size (you will not be able to zoom/pan/home until the next frame). \n\n4. Hover your mouse over data points to see the labels and their likelihood. \n\n5. LEFT click+drag to move the data points. \n\n6. RIGHT click on any data point to remove it. Be careful, you cannot undo this step! \n Click once on the zoom button to zoom-in the image. The cursor will become cross, click and drag over a point to zoom in. \n Click on the zoom button again to disable the zooming function and recover the cursor. \n Use pan button to pan across the image while zoomed in. Use home button to go back to the full default view. \n\n7. When finished click \'Save\' to save all the changes. \n\n8. Click OK to continue', 'User instructions', wx.OK | wx.ICON_INFORMATION)

    def onChecked(self, event):
        """Store pending edits, then enable the marker-size slider iff the box is checked."""
        MainFrame.confirm(self)
        self.cb = event.GetEventObject()
        self.slider.Enable(bool(self.cb.GetValue()))

    def check_labels(self):
        """Set to NaN every label whose x/y lies outside its image; return the table."""
        print("Checking labels if they are outside the image")
        for i in self.Dataframe.index:
            image_name = os.path.join(self.project_path, i)
            # Context manager so the image file handle is released.
            with PIL.Image.open(image_name) as im:
                width, height = im.size
            for bp in self.bodyparts2plot:
                x = self.Dataframe.loc[i, (self.scorer, bp, 'x')]
                y = self.Dataframe.loc[i, (self.scorer, bp, 'y')]
                if x > width or x < 0 or y > height or y < 0:
                    print("Found %s outside the image %s.Setting it to NaN" % (bp, i))
                    self.Dataframe.loc[i, (self.scorer, bp, 'x')] = np.nan
                    self.Dataframe.loc[i, (self.scorer, bp, 'y')] = np.nan
        return self.Dataframe

    def save(self, event):
        """Write the refined labels to CollectedData_<humanscorer>.h5/.csv in self.dir.

        In machine-label mode, the scorer level is renamed to the human scorer.
        The likelihood columns are dropped. If the target file already exists,
        the rows are merged with it, and the refined rows win on duplicate index.
        """
        MainFrame.confirm(self)
        plt.close(self.fig1)

        base_path = os.path.join(self.dir, 'CollectedData_' + self.humanscorer)
        h5_path = base_path + '.h5'
        csv_path = base_path + '.csv'

        self.Dataframe = MainFrame.check_labels(self)
        if self.adjust_original_labels:
            self.Dataframe.to_hdf(h5_path, key='df_with_missing', mode='w')
            self.Dataframe.to_csv(csv_path)
        else:
            # MultiIndex.set_levels no longer accepts inplace=True (pandas >= 2.0).
            self.Dataframe.columns = self.Dataframe.columns.set_levels([self.humanscorer], level=0)
            self.Dataframe = self.Dataframe.drop('likelihood', axis=1, level=2)

        if Path(h5_path).is_file():
            print("A training dataset file is already found for this video. The refined machine labels are merged to this data!")
            DataU1 = pd.read_hdf(h5_path, 'df_with_missing')
            # Combine refined data with the existing file; keep='first' gives
            # the refined rows preference on duplicate index entries.
            DataCombined = pd.concat([self.Dataframe, DataU1])
            DataCombined = DataCombined[~DataCombined.index.duplicated(keep='first')]
            DataCombined.to_hdf(h5_path, key='df_with_missing', mode='w')
            DataCombined.to_csv(csv_path)
        else:
            self.Dataframe.to_hdf(h5_path, key='df_with_missing', mode='w')
            self.Dataframe.to_csv(csv_path)
            self.Button2.Enable(False)
            self.Button3.Enable(False)
            self.slider.Enable(False)
            self.checkBox.Enable(False)

        nextFilemsg = wx.MessageBox('File saved. Do you want to refine another file?', 'Repeat?', wx.YES_NO | wx.ICON_INFORMATION)
        if nextFilemsg == wx.YES:
            self.file = 1
            self.canvas.Destroy()
            plt.close(self.fig1)
            self.Button1.Enable(True)
            self.slider.Enable(False)
            self.checkBox.Enable(False)
            MainFrame.browseDir(self, event)
        else:
            print("Closing... The refined labels are stored in a subdirectory under labeled-data. Use the function 'merge_datasets' to augment the training dataset, and then re-train a network using create_training_dataset followed by train_network!")
            self.Destroy()

# ###########################################################################
# Other functions
# ###########################################################################
    def confirm(self):
        """
        Copy the last dragged position of each point into the dataframe row of the current frame.
        """
        plt.close(self.fig1)
        for bpindex, bp in enumerate(self.bodyparts2plot):
            if self.updatedCoords[bpindex]:
                last_x, last_y = self.updatedCoords[bpindex][-1][0], self.updatedCoords[bpindex][-1][1]
                self.Dataframe.loc[self.Dataframe.index[self.iter], (self.scorer, bp, 'x')] = last_x
                self.Dataframe.loc[self.Dataframe.index[self.iter], (self.scorer, bp, 'y')] = last_y

    def plot(self, im, im_axis):
        """
        Draw the colorbar legend and one draggable circle per body part.

        A point is drawn hollow when its likelihood is below the threshold.
        This applies only to machine-label files once a threshold is set.
        Otherwise it is drawn filled.
        """
        cbar = self.fig1.colorbar(im_axis, ax=self.ax1f1)

        # Small hack for 0-intensity images: use 255 as the tick span (the old
        # code still used the real max as range stop, giving no ticks at all).
        maxIntensity = int(np.max(im))
        if maxIntensity == 0:
            maxIntensity = 255
        tick_step = max(1, maxIntensity // self.num_joints)
        cbar.set_ticks(range(12, maxIntensity, tick_step))
        cbar.set_ticklabels(self.bodyparts2plot)
        normalize = mcolors.Normalize(vmin=np.min(self.colorparams), vmax=np.max(self.colorparams))

        is_collected = 'CollectedData_' in self.fileName
        for bpindex, bp in enumerate(self.bodyparts2plot):
            color = self.colormap(normalize(bpindex))
            bp_data = self.Dataframe[self.scorer][bp]
            if is_collected:
                likelihood = 1.0
            else:
                likelihood = bp_data['likelihood'].values[self.iter]
            self.points = [bp_data['x'].values[self.iter], bp_data['y'].values[self.iter], likelihood]
            self.likelihood = self.points[2]

            if self.move2corner == True:
                ny, nx = np.shape(im)[0], np.shape(im)[1]
                if self.points[0] > nx or self.points[0] < 0:
                    self.points[0] = self.center[0]
                if self.points[1] > ny or self.points[1] < 0:
                    self.points[1] = self.center[1]

            below_threshold = (not is_collected
                               and self.threshold is not None
                               and self.likelihood < self.threshold)
            if below_threshold:
                circle = patches.Circle((self.points[0], self.points[1]), radius=self.markerSize, facecolor='None', edgecolor=color)
            else:
                circle = patches.Circle((self.points[0], self.points[1]), radius=self.markerSize, fc=color, alpha=self.alpha)

            self.ax1f1.add_patch(circle)
            # auxfun_drag handles moving/removing the point on the canvas.
            self.dr = auxfun_drag.DraggablePoint(circle, bp, self.likelihood, self.adjust_original_labels)
            self.dr.connect()
            self.drs.append(self.dr)
            self.updatedCoords.append(self.dr.coords)
Exemplo n.º 4
0
class MainFrame(wx.Frame):
    """Contains the main GUI and button boxes

    Labeling window: the user picks a folder of extracted ``*.png`` frames,
    places one marker per body part with a right click, steps through the
    frames and saves the annotations as ``CollectedData_<scorer>.csv`` and
    ``.h5`` inside that folder.
    """

    def __init__(self, parent, config, Screens, scale_w, scale_h, winHack, img_scale):
        """Build the window, its buttons and the "add new labels" check box.

        Args:
            parent: Not used; the frame is always created top-level
                (``None`` is passed to ``wx.Frame.__init__``).
            config: Path of the project's config.yaml, read in ``browseDir``.
            Screens: Number of screens the display width is divided by.
            scale_w: Fraction of the (per-screen) display width to use.
            scale_h: Fraction of the display height to use.
            winHack: Extra factor applied to the frame size only.
            img_scale: Factor from GUI size to the matplotlib ``figsize``.
        """
        displaysize = wx.GetDisplaySize()

        w = displaysize[0]
        h = displaysize[1]
        self.gui_width = (w * scale_w) / Screens
        self.gui_height = (h * scale_h)

        if self.gui_width < 600 or self.gui_height < 500:
            print("Your screen width", w, "and height", h)
            print("Scaled GUI width", self.gui_width, "and height", self.gui_height)
            print("Please adjust scale_h and scale_w, or get a bigger screen!")

        self.size = displaysize

        # Sizes/positions are converted to int: recent wxPython/Python
        # combinations reject float coordinates instead of truncating them.
        wx.Frame.__init__(self, None, title="DeepLabCut2.0 - Labeling GUI",
                          size=(int(self.gui_width * winHack), int(self.gui_height * winHack)),
                          style=wx.DEFAULT_FRAME_STYLE)

        self.statusbar = self.CreateStatusBar()
        self.statusbar.SetStatusText("")
        self.Bind(wx.EVT_CHAR_HOOK, self.OnKeyPressed)

        self.SetBackgroundColour("#ffffff")

        self.Button1 = wx.Button(self, -1, "Load Frames", size=(150, 40), pos=self._pos(.1, .9))
        self.Button1.Bind(wx.EVT_BUTTON, self.browseDir)
        self.Button1.Enable(True)

        self.Button5 = wx.Button(self, -1, "Help", size=(80, 40), pos=self._pos(.3, .9))
        self.Button5.Bind(wx.EVT_BUTTON, self.help)
        self.Button5.Enable(True)

        self.Button2 = wx.Button(self, -1, "Next Frame", size=(120, 40), pos=self._pos(.4, .9))
        self.Button2.Bind(wx.EVT_BUTTON, self.nextImage)
        self.Button2.Enable(False)

        self.Button4 = wx.Button(self, -1, "Save", size=(80, 40), pos=self._pos(.6, .9))
        self.Button4.Bind(wx.EVT_BUTTON, self.save)
        self.Button4.Enable(False)
        self.close = wx.Button(self, -1, "Quit", size=(80, 40), pos=self._pos(.8, .9))
        self.close.Bind(wx.EVT_BUTTON, self.quitButton)

        # Zoom / pan / home drive the matplotlib toolbar created in browseDir.
        self.Button8 = wx.Button(self, -1, "Zoom", size=(60, 30), pos=self._pos(.65, .85))
        self.Button8.Bind(wx.EVT_BUTTON, self.zoom)

        self.Button7 = wx.Button(self, -1, "Pan", size=(60, 30), pos=self._pos(.75, .85))
        self.Button7.Bind(wx.EVT_BUTTON, self.pan)

        self.Button6 = wx.Button(self, -1, "Home", size=(60, 30), pos=self._pos(.85, .85))
        self.Button6.Bind(wx.EVT_BUTTON, self.home)

        # Define variables
        self.currentDirectory = os.getcwd()
        self.index = []
        self.iter = []
        self.colormap = cm.hsv  # overwritten by the config file's colormap in browseDir
        self.file = 0  # 0 until the first "Next Frame"; controls slider creation
        self.updatedCoords = []
        self.dataFrame = None
        self.flag = True
        self.config_file = config
        self.addLabel = wx.CheckBox(self, label='Add new labels to existing dataset?', pos=self._pos(.1, .85))
        self.addLabel.Bind(wx.EVT_CHECKBOX, self.newLabel)
        self.new_labels = False
        imgW = self.gui_width * img_scale
        imgH = self.gui_height * img_scale
        self.img_size = (imgW, imgH)  # matplotlib figsize (width, height) in inches

    # ------------------------------------------------------------------ helpers
    def _pos(self, fx, fy):
        """Return an integer (x, y) position at fractions fx, fy of the GUI size."""
        return (int(self.gui_width * fx), int(self.gui_height * fy))

    def _set_title(self):
        """Title the axes with the 1-based frame number, frame count and file name."""
        img_name = Path(self.index[self.iter]).name
        self.ax1f1.set_title(str(self.iter + 1) + "/" + str(len(self.index)) + " " + img_name)

    def _set_colorbar_ticks(self, cbar, im):
        """Label the colorbar with one tick per body part.

        Ticks start at intensity 12 and are spaced ``max(im) / n_bodyparts``
        apart. The step is clamped to >= 1 (a 0 step made ``range`` raise for
        dark images) and the tick list is truncated so it never has more
        entries than there are labels. Nothing is labeled when there are no
        body parts.
        """
        n_parts = len(self.bodyparts)
        if n_parts == 0:
            return
        top = int(np.max(im))
        step = max(1, int(np.floor(top / n_parts)))
        ticks = list(range(12, top, step))[:n_parts]
        cbar.set_ticks(ticks)
        cbar.set_ticklabels(self.bodyparts[:len(ticks)])

    def _draw_current_frame(self):
        """Show frame ``self.iter`` in a new titled figure on a fresh canvas.

        Returns:
            (im, im_axis): the image array and the AxesImage, so the caller
            can add a colorbar.
        """
        from skimage import io
        self.fig1, self.ax1f1 = plt.subplots(figsize=self.img_size, facecolor="None")
        im = io.imread(self.index[self.iter])
        im_axis = self.ax1f1.imshow(im, self.colormap)
        self._set_title()
        self.canvas = FigureCanvasWxAgg(self, -1, self.fig1)
        return im, im_axis

    def _add_bodypart_selector(self):
        """Create the radio box listing the body parts that can be annotated."""
        self.rdb = wx.RadioBox(self, id=1, label="Select a body part to annotate",
                               pos=self._pos(.83, .1), choices=self.bodyparts,
                               majorDimension=1, style=wx.RA_SPECIFY_COLS,
                               validator=wx.DefaultValidator, name=wx.RadioBoxNameStr)
        self.option = self.rdb.Bind(wx.EVT_RADIOBOX, self.onRDB)

    def _run_toolbar(self, label, action):
        """Show label in the status bar and call the named toolbar action.

        Only the status text changes until ``browseDir`` has created
        ``self.toolbar``.
        """
        self.statusbar.SetStatusText(label)
        toolbar = getattr(self, 'toolbar', None)
        if toolbar is None:
            return
        getattr(toolbar, action)()
        self.Refresh(eraseBackground=True)

    # ---------------------------------------------------------------- handlers
    def newLabel(self, event):
        """Check-box handler: annotate only body parts missing from the saved dataset.

        Once ticked the check box is disabled, so the choice cannot be undone.
        """
        self.chk = event.GetEventObject()
        if self.chk.GetValue():
            self.new_labels = True
            self.addLabel.Enable(False)
        else:
            self.new_labels = False

    # BUTTONS FUNCTIONS
    def OnKeyPressed(self, event=None):
        """Advance to the next frame on the RIGHT arrow key.

        Other keys are passed on with ``event.Skip()``: an ``EVT_CHAR_HOOK``
        handler that does not skip swallows every key press in the window.
        """
        if event is None:
            return
        if event.GetKeyCode() == wx.WXK_RIGHT:
            self.nextImage(event=None)
        else:
            event.Skip()

    def zoom(self, event):
        """Toggle the matplotlib toolbar's zoom-to-rectangle mode."""
        self._run_toolbar("Zoom", "zoom")

    def home(self, event):
        """Restore the original (un-zoomed) view."""
        self._run_toolbar("Home", "home")

    def pan(self, event):
        """Toggle the matplotlib toolbar's pan mode."""
        self._run_toolbar("Pan", "pan")

    def quitButton(self, event):
        """
        Quits the GUI
        """
        self.Destroy()

    def help(self, event):
        """
        Opens Instructions
        """
        wx.MessageBox('1. Select one of the body parts from the radio buttons to add a label (if necessary change config.yaml first to edit the label names). \n\n2. RIGHT clicking on the image will add the selected label. \n The label will be marked as circle filled with a unique color. \n\n3. Hover your mouse over this newly added label to see its name. \n\n4. LEFT click and drag to move the label position. \n\n5. To change the marker size mark the checkbox and move the slider. Uncheck this after it is adjusted! Then advance to the next frame (you cannot zoom or pan again on this image). \n Change the markersize only after finalizing the position of your FIRST LABEL! \n\n6. Once you are happy with the position, select another body part from the radio button. \n Be careful, once you add a new body part, you will not be able to move the old labels. \n\n7. Click Next Frame to move to the next image. \n\n8. When finished labeling all the images, click \'Save\' to save all the labels as a .h5 file. \n\n9. Click OK to continue using the labeling GUI.', 'User instructions', wx.OK | wx.ICON_INFORMATION)

    def onClick(self, event):
        """Handle a mouse press on the image canvas.

        Right click (button 3) places a draggable marker for the body part
        selected in the radio box. If that part already has a marker on this
        frame, the selection first advances to the next body part. If no
        next part exists, an error is shown and no marker is added. Clicks
        outside the axes are ignored. Middle click (button 2) toggles zoom.
        """
        x1 = event.xdata
        y1 = event.ydata
        self.drs = []
        normalize = mcolors.Normalize(vmin=np.min(self.colorparams), vmax=np.max(self.colorparams))
        if event.button == 3:
            if x1 is None or y1 is None:
                return  # click outside the image axes
            if self.rdb.GetSelection() in self.buttonCounter:
                try:
                    new_sel = self.buttonCounter[-1] + 1
                    self.rdb.Select(new_sel)
                    self.buttonCounter.append(new_sel)
                except Exception:
                    # No next body part to advance to: warn and do not add a
                    # second marker for the already annotated one.
                    wx.MessageBox('%s is already annotated. \n Select another body part to annotate.' % (str(self.bodyparts[self.rdb.GetSelection()])), 'Error!', wx.OK | wx.ICON_ERROR)
                    return

            if self.flag == len(self.bodyparts):
                wx.MessageBox('All body parts are annotated! Click \'Save\' to save the changes. \n Click OK to continue.', 'Done!', wx.OK | wx.ICON_INFORMATION)
                # mpl_disconnect takes the connection id, not the callback.
                self.canvas.mpl_disconnect(self.cidClick)
                return

            sel = self.rdb.GetSelection()
            color = self.colormap(normalize(sel))
            circle = [patches.Circle((x1, y1), radius=self.markerSize, fc=color, alpha=0.5)]
            self.num.append(circle)
            self.ax1f1.add_patch(circle[0])
            self.dr = auxfun_drag_label.DraggablePoint(circle[0], self.bodyparts[sel])
            self.dr.connect()
            self.buttonCounter.append(sel)
            # coords entries: [x, y, bodypart name, radio-box index]
            self.dr.coords = [[x1, y1, self.bodyparts[sel], sel]]
            self.drs.append(self.dr)
            self.updatedCoords.append(self.dr.coords)
        elif event.button == 2:
            self.zoom(None)

    def browseDir(self, event):
        """
        Show the DirDialog and ask the user to change the directory where machine labels are stored

        Loads the project config, lists the ``*.png`` frames of the chosen
        folder, shows the first one and prepares an all-NaN DataFrame
        (rows: frames; columns: scorer / bodypart / x, y) for the labels.
        Cancelling the dialog closes the GUI.
        """
        dlg = wx.DirDialog(self, "Choose the directory where your extracted frames are saved:",
                           os.path.join(os.getcwd(), 'labeled-data'), style=wx.DD_DEFAULT_STYLE)
        if dlg.ShowModal() != wx.ID_OK:
            # Destroy the dialog once and stop: there is no directory to load.
            dlg.Destroy()
            self.Close(True)
            return
        self.dir = dlg.GetPath()
        dlg.Destroy()
        self.Button1.Enable(False)
        self.Button2.Enable(True)
        self.Button5.Enable(True)

        with open(str(self.config_file), 'r') as ymlfile:
            # yaml.load without a Loader is deprecated (PyYAML 5.1) and an
            # error in PyYAML 6; the config is plain YAML.
            self.cfg = yaml.safe_load(ymlfile)
        self.scorer = self.cfg['scorer']
        self.bodyparts = self.cfg['bodyparts']
        self.videos = self.cfg['video_sets'].keys()
        self.markerSize = self.cfg['dotsize']
        self.colormap = plt.get_cmap(self.cfg['colormap'])
        self.project_path = self.cfg['project_path']
        # Sorted so frames are shown (and stored) in a deterministic order.
        self.index = sorted(glob.glob(os.path.join(self.dir, '*.png')))
        print('Working on folder: {}'.format(os.path.split(str(self.dir))[-1]))
        if not self.index:
            print("No .png frames found in {}. Please choose a folder with extracted frames.".format(self.dir))
            self.Button1.Enable(True)
            self.Button2.Enable(False)
            return

        # Row labels relative to the project: 'labeled-data/<video>/<frame>.png'
        self.relativeimagenames = ['labeled' + n.split('labeled')[1] for n in self.index]

        # checks for unique bodyparts
        if len(self.bodyparts) != len(set(self.bodyparts)):
            print("Error! bodyparts must have unique labels! Please choose unique bodyparts in config.yaml file and try again. Quiting for now!")
            self.Destroy()
            return

        self.iter = 0
        self.buttonCounter = []
        im, im_axis = self._draw_current_frame()
        self.toolbar = NavigationToolbar(self.canvas)

        if self.new_labels:
            self.oldDF = pd.read_hdf(os.path.join(self.dir, 'CollectedData_' + self.scorer + '.h5'), 'df_with_missing')
            already_labeled = set(self.oldDF.columns.get_level_values(1))
            # Keep config order (a set difference gave a random order per run).
            self.bodyparts = [bp for bp in self.bodyparts if bp not in already_labeled]
        else:
            self.addLabel.Enable(False)
        self._add_bodypart_selector()
        cbar = self.fig1.colorbar(im_axis, ax=self.ax1f1)
        self._set_colorbar_ticks(cbar, im)

        self.cidClick = self.canvas.mpl_connect('button_press_event', self.onClick)
        self.flag = 0
        self.num = []
        self.counter = []
        self.presentCoords = []

        self.colorparams = list(range(0, len(self.bodyparts) + 1))

        frames = [
            pd.DataFrame(np.full((len(self.index), 2), np.nan),
                         columns=pd.MultiIndex.from_product([[self.scorer], [bodypart], ['x', 'y']],
                                                            names=['scorer', 'bodyparts', 'coords']),
                         index=self.relativeimagenames)
            for bodypart in self.bodyparts
        ]
        if frames:
            self.dataFrame = pd.concat([self.dataFrame] + frames, axis=1)

        if self.file == 0:
            self.checkBox = wx.CheckBox(self, label='Adjust marker size.', pos=self._pos(.43, .85))
            self.checkBox.Bind(wx.EVT_CHECKBOX, self.onChecked)
            self.slider = wx.Slider(self, -1, 18, 0, 20, size=(200, -1), pos=self._pos(.40, .78),
                                    style=wx.SL_HORIZONTAL | wx.SL_AUTOTICKS | wx.SL_LABELS)
            self.slider.Bind(wx.EVT_SLIDER, self.OnSliderScroll)
            self.slider.Enable(True)

    def onRDB(self, event):
        """Remember the selected body-part index and log it in self.counter."""
        self.option = self.rdb.GetSelection()
        self.counter.append(self.option)

    def nextImage(self, event):
        """
        Moves to next image

        Stores the current frame's labels, then shows the next frame. On the
        last frame it only disables "Next Frame" and enables "Save". Does
        nothing before frames are loaded (the RIGHT key can trigger it).
        """
        if not self.index:
            return
        if len(self.index) - self.iter == 1:
            self.Button2.Enable(False)
            self.Button4.Enable(True)
            return

        self.file = 1
        self.saveEachImage()
        self.canvas.Destroy()
        plt.close(self.fig1)
        self.ax1f1.clear()
        self.iter = self.iter + 1
        # Refreshing the button counter
        self.buttonCounter = []
        self.rdb.SetSelection(0)
        self.updatedCoords = []

        im, im_axis = self._draw_current_frame()
        cbar = self.fig1.colorbar(im_axis, ax=self.ax1f1)
        self._set_colorbar_ticks(cbar, im)
        self.cidClick = self.canvas.mpl_connect('button_press_event', self.onClick)
        # Recreate toolbar for zooming
        self.toolbar = NavigationToolbar(self.canvas)

    def saveEachImage(self):
        """
        Saves data for each image

        Writes the last known (x, y) of every marker on the current frame
        into ``self.dataFrame``. A single ``.loc[row, col]`` is used; the
        former chained ``.loc[row][col] = v`` could write into a copy.
        """
        plt.close(self.fig1)

        row = self.relativeimagenames[self.iter]
        for coords in self.updatedCoords:
            bodypart = coords[0][-2]
            x, y = coords[-1][0], coords[-1][1]
            self.dataFrame.loc[row, (self.scorer, bodypart, 'x')] = x
            self.dataFrame.loc[row, (self.scorer, bodypart, 'y')] = y

    def save(self, event):
        """
        Saves the final dataframe

        Writes CSV and HDF5 files to the labeled folder, then offers to
        label another folder (Yes) or closes the GUI (No).
        """
        self.saveEachImage()
        if self.new_labels:
            self.dataFrame = pd.concat([self.oldDF, self.dataFrame], axis=1)
        # Windows compatible
        self.dataFrame.to_csv(os.path.join(self.dir, "CollectedData_" + self.scorer + ".csv"))
        self.dataFrame.to_hdf(os.path.join(self.dir, "CollectedData_" + self.scorer + '.h5'), 'df_with_missing', format='table', mode='w')

        nextFilemsg = wx.MessageBox('File saved. Do you want to label another data set?', 'Repeat?', wx.YES_NO | wx.ICON_INFORMATION)
        if nextFilemsg == wx.YES:
            self.file = 1
            plt.close(self.fig1)
            self.canvas.Destroy()
            self.rdb.Destroy()
            self.buttonCounter = []
            self.updatedCoords = []
            self.dataFrame = None
            self.counter = []
            self.bodyparts = []
            self.Button1.Enable(True)
            self.slider.Enable(False)
            self.checkBox.Enable(False)
            self.browseDir(event)
        else:
            self.Destroy()
            print("You can now check the labels, using 'check_labels' before proceeding. Then,  you can use the function 'create_training_dataset' to create the training dataset.")

    def onChecked(self, event):
        """Enable the marker-size slider while the check box is ticked.

        The click handler is replaced rather than added again, so one right
        click still adds exactly one marker.
        """
        self.cb = event.GetEventObject()
        if self.cb.GetValue():
            self.slider.Enable(True)
            self.canvas.mpl_disconnect(self.cidClick)
            self.cidClick = self.canvas.mpl_connect('button_press_event', self.onClick)
        else:
            self.slider.Enable(False)

    def OnSliderScroll(self, event):
        """
        Adjust marker size for plotting the annotations

        Redraws the current frame on a new canvas. It draws the existing
        markers (uncolored, not draggable) at the slider's size. It then
        reconnects the click handler once and rebuilds the zoom toolbar.
        """
        self.drs = []
        plt.close(self.fig1)
        self.canvas.Destroy()
        self.markerSize = self.slider.GetValue()
        im, im_axis = self._draw_current_frame()
        cbar = self.fig1.colorbar(im_axis, ax=self.ax1f1)
        self._set_colorbar_ticks(cbar, im)

        for coords in self.updatedCoords:
            x1, y1 = coords[-1][0], coords[-1][1]
            self.ax1f1.add_patch(patches.Circle((x1, y1), radius=self.markerSize, alpha=0.5))
        # Connect once (not once per marker, which multiplied click handlers).
        self.cidClick = self.canvas.mpl_connect('button_press_event', self.onClick)
        self.toolbar = NavigationToolbar(self.canvas)
Exemplo n.º 5
0
class RecordWindow(wx.Frame):
    """Audio recorder window: live input-level plot, detected note and octave bar.

    Every 100 ms a wx.Timer calls ``record``. It reads one block from a
    PyAudio float32 input stream, runs aubio note detection on it and
    appends the scaled mean absolute amplitude to the plot.
    """

    # Octave-bar colours from the top rectangle (y=150) downwards, 30 px
    # apart; rectangle i is filled when the current octave is > 6 - i.
    # Strings are wx colour names, tuples are RGB values for wx.Colour.
    _OCTAVE_COLOURS = ("red", (255, 165, 0), "yellow", (0, 255, 0),
                       "blue", (63, 0, 255), (128, 0, 128))

    def __init__(self, parent):
        """Open the audio stream, load MIDI note names and build the GUI.

        Args:
            parent: Owning window. ``open_synth`` reads ``parent.synth``;
                ``OnClose`` sets ``parent.audio = None`` and destroys it.
        """
        self.stream = None
        self.p = None
        self.parent = parent
        self.dc = None  # kept for compatibility; OnPaint uses a local PaintDC
        self.synth = None

        super().__init__(parent=None, title='Audio Recorder')
        plt.style.use('dark_background')

        # prep input stream for audio
        self.p = pyaudio.PyAudio()
        self.stream = self.p.open(format=pyaudio.paFloat32,
                                  channels=1,
                                  rate=RATE,
                                  input=True,
                                  frames_per_buffer=CHUNK)

        win_s = 4096
        hop_s = CHUNK
        self.notes_o = aubio.notes("default", win_s, hop_s, RATE)

        # prep variables for plotting
        self.xs = []
        self.ys = []

        self.fig = plt.figure()
        self.ax = plt.axes(xlim=(0, 100), ylim=(0, 2000))
        self.line, = self.ax.plot([], [])
        self.line.set_data(self.xs, self.ys)

        # scrape midi number to note conversion data online
        self.miditonote = self._scrape_midi_names(
            "https://www.inspiredacoustics.com/en/MIDI_note_numbers_and_center_frequencies")

        # prep GUI
        self.panel = wx.Panel(self, size=(780, 480))
        self.canvas = FigureCanvas(self.panel, -1, self.fig)

        # sizer for graph
        self.graphsizer = wx.BoxSizer(wx.VERTICAL)
        self.graphsizer.Add(self.canvas, 1, wx.TOP | wx.LEFT | wx.GROW)

        # sizer for note and audio
        self.audiosizer = wx.BoxSizer(wx.VERTICAL)
        self.notetext = wx.StaticText(self.panel, style=wx.ALIGN_LEFT)
        self.notetext.SetForegroundColour('blue')
        self.notebox = wx.StaticBox(self.panel, size=(150, 50))
        self.noteboxsizer = wx.StaticBoxSizer(self.notebox, wx.VERTICAL)
        self.noteboxsizer.Add(self.notetext, 0, wx.ALIGN_LEFT)
        self.audiobutton = wx.Button(self.panel, -1, "Pause")
        self.synthbutton = wx.Button(self.panel, -1, "Open Synthesizer")
        self.audiosizer.AddSpacer(10)
        self.audiosizer.Add(self.audiobutton, 0, wx.ALIGN_CENTER | wx.ALL, border=5)
        self.audiosizer.AddSpacer(10)
        self.audiosizer.Add(self.noteboxsizer, 0, wx.ALIGN_CENTER | wx.ALL, border=5)
        self.audiosizer.AddSpacer(280)
        self.audiosizer.Add(self.synthbutton, 0, wx.ALIGN_CENTER | wx.ALL, border=5)

        self.audiobutton.Bind(wx.EVT_BUTTON, self.pause_play)
        self.Bind(wx.EVT_PAINT, self.OnPaint)
        self.synthbutton.Bind(wx.EVT_BUTTON, self.open_synth)

        # add both components to 1 sizer
        self.mainsizer = wx.BoxSizer(wx.HORIZONTAL)
        self.mainsizer.Add(self.graphsizer, 1)
        self.mainsizer.Add(self.audiosizer, 1, wx.RIGHT)
        self.panel.SetSizer(self.mainsizer)
        self.Fit()
        self.panel.Layout()

        plt.ion()
        # prep timer
        self.timercount = 0
        self.timer = wx.Timer(self)
        self.Bind(wx.EVT_TIMER, self.record, self.timer)
        self.Bind(wx.EVT_CLOSE, self.OnClose)
        self.timer.Start(100)

    @staticmethod
    def _scrape_midi_names(url):
        """Return {midi_number: note_name} read from the first HTML table at url.

        Uses column 0 (MIDI number) and column 3 (note name) of each row.
        A download failure or a missing table yields an empty dict, so the
        recorder still runs but shows no note names.
        """
        try:
            page = get(url, timeout=5)
        except Exception as err:  # network errors must not prevent the window opening
            print("Could not download the MIDI note table: {}".format(err))
            return {}
        soup = BeautifulSoup(page.content, 'html.parser')
        table = soup.find('table')
        if table is None:
            return {}
        names = {}
        for row in table.find_all('tr'):
            cells = row.find_all('td')
            if len(cells) > 3 and cells[0].text.isdigit():
                names[int(cells[0].text)] = cells[3].text.strip()
        return names

    def record(self, event):
        """Timer handler: read one audio block, detect a note, extend the level plot."""
        if not (self.stream and self.stream.is_active()):
            return
        # aubio.notes was built with hop size CHUNK and requires exactly
        # that many samples per call.
        data = np.frombuffer(self.stream.read(CHUNK, exception_on_overflow=False),
                             dtype=np.float32)
        note = self.notes_o(data)[0]
        peak = np.average(np.abs(data)) * 2000

        if note != 0:
            name = self.miditonote.get(int(round(note)))
            if name is not None:  # SetLabel(None) would raise
                self.notetext.SetLabel(name)
                if self.synth:
                    self.synth.input_note.SetValue(name)

        self.xs.append(self.timercount)
        self.timercount += 1
        self.ys.append(peak)
        self.line.set_data(self.xs, self.ys)
        self.ax.relim()
        self.ax.autoscale()
        plt.plot()

    def pause_play(self, event):
        """Stop or restart the input stream and flip the button label."""
        if self.audiobutton.GetLabel() == "Pause":
            self.stream.stop_stream()
            self.audiobutton.SetLabel("Play")
        elif self.audiobutton.GetLabel() == "Play":
            self.stream.start_stream()
            self.audiobutton.SetLabel("Pause")

    def open_synth(self, event):
        """Open a SynthWindow unless this window or the parent already has one."""
        if not self.parent.synth and not self.synth:
            self.synth = SynthWindow(self)

    def _current_octave(self):
        """Return the (possibly negative) octave number ending the shown note name, or -1.

        For names such as "C-1" this yields -1. A name without a trailing
        number also yields -1, where indexing its last character used to
        raise.
        """
        import re
        match = re.search(r"(-?\d+)\s*$", self.notetext.GetLabel())
        return int(match.group(1)) if match else -1

    def OnPaint(self, event):
        """Draw the octave bar: outlined rectangles, filled up to the current octave."""
        dc = wx.PaintDC(self)  # a PaintDC is only valid inside the paint handler
        dc.Clear()
        octaveno = self._current_octave()

        highest = len(self._OCTAVE_COLOURS) - 1
        for i, spec in enumerate(self._OCTAVE_COLOURS):
            colour = wx.Colour(*spec) if isinstance(spec, tuple) else spec
            dc.SetPen(wx.Pen(colour, style=wx.SOLID))
            if octaveno > highest - i:
                dc.SetBrush(wx.Brush(colour, style=wx.SOLID))
            else:
                dc.SetBrush(wx.NullBrush)
            dc.DrawRectangle(720, 150 + 30 * i, 50, 25)

        self.panel.Refresh()

    def OnClose(self, event):
        """Stop the timer, release audio resources and close this window and its parent."""
        # Stop the timer first so record() never runs against a closed stream.
        self.timer.Stop()
        if self.stream:
            self.stream.stop_stream()
            self.stream.close()
            self.stream = None
        if self.p:
            self.p.terminate()
            self.p = None
        if self.canvas:
            self.canvas.Destroy()
        if self.synth:
            self.synth.Destroy()
        plt.close('all')
        self.Destroy()
        self.parent.audio = None
        self.parent.Destroy()
Exemplo n.º 6
0
class MatplotPanel(wx.Panel):
    """
    From http://stackoverflow.com/a/19898295/1109980
    """
    def __init__(self, *args, **kwargs):
        """Create the panel with an empty figure and its navigation toolbar.

        All arguments are forwarded unchanged to ``wx.Panel.__init__``.
        """
        wx.Panel.__init__(self, *args, **kwargs)

        # Plot selections, filled in by changePlot (None = nothing chosen).
        self.subplots = self.rows = self.cols = None
        self.yerr = self.values = None

        self.sizer = wx.BoxSizer(wx.VERTICAL)
        self.SetSizer(self.sizer)

        empty_plot = plot.Plot()
        self.canvas = FigureCanvas(self, -1, empty_plot.fig)
        self.toolbar = NavigationToolbar(self.canvas)

        for widget, proportion, flag in ((self.toolbar, 0, wx.EXPAND),
                                         (self.canvas, 1, wx.GROW)):
            self.sizer.Add(widget, proportion, flag)
        self.Fit()

    def redraw(self, agg):
        """Replace the shown figure with a plot of the aggregated data agg.

        The plot kind and error-bar kind come from the frame's ``plot_type``
        and ``err_type`` controls. The new figure is built before the old
        canvas and toolbar are destroyed. A plotting error therefore leaves
        the current, still-valid widgets in place. Previously they were
        already destroyed, and the next redraw failed on them.
        """
        kind = self.frame.plot_type.GetItemLabel(
            self.frame.plot_type.GetSelection())
        errkind = self.frame.err_type.GetItemLabel(
            self.frame.err_type.GetSelection())
        # Named `plotter` so it does not shadow matplotlib's usual `plt`.
        plotter = plot.Plot()
        plotter.plot(agg, kind=kind, errkind=errkind.lower())

        self.canvas.Destroy()
        self.toolbar.Destroy()
        self.canvas = FigureCanvas(self, -1, plotter.fig)
        self.toolbar = NavigationToolbar(self.canvas)

        self.sizer.Add(self.toolbar, 0, wx.EXPAND)
        self.sizer.Add(self.canvas, 1, wx.LEFT | wx.TOP | wx.GROW)
        self.Layout()

    def draw_empty(self):
        """Show a blank figure (and a fresh toolbar) in place of the current one."""
        for old_widget in (self.canvas, self.toolbar):
            old_widget.Destroy()
        blank = plot.Plot()

        self.canvas = FigureCanvas(self, -1, blank.fig)
        self.toolbar = NavigationToolbar(self.canvas)

        self.sizer.Add(self.toolbar, 0, wx.EXPAND)
        self.sizer.Add(self.canvas, 1, wx.GROW)
        self.Layout()

    def _get_items(self, parent, event=None):
        """Return the texts of all items in list control parent, or None if it is empty.

        parent is the list where the item is being inserted. When event is a
        list-delete event fired by parent itself, the item being deleted is
        left out (matched by its text). The item is presumably still in the
        control while the event is processed. A ValueError is raised if that
        text is not in the list.
        """
        items = [parent.GetItemText(i) for i in range(parent.GetItemCount())]

        deleting_from_parent = (
            event is not None
            and event.GetEventType() == wx.EVT_LIST_DELETE_ITEM.evtType[0]
            and parent == event.GetEventObject()
        )
        if deleting_from_parent:
            items.remove(event.GetText())

        return items if items else None

    def changePlot(self, event=None):
        """Re-read the subplot/row/column/value/error selections from the frame and replot.

        Each selection ``name`` is taken from the frame's ``list_<name>``
        control via _get_items; event is forwarded so a pending deletion is
        already excluded.
        """
        for name in ('subplots', 'rows', 'cols', 'values', 'yerr'):
            control = getattr(self.frame, 'list_' + name)
            setattr(self, name, self._get_items(control, event))
        self.plot()

    def plot(self):
        if self.values is None or (self.cols is None and self.rows is None):
            self.draw_empty()
        else:
            value_type = self.frame.value_type.GetItemLabel(
                self.frame.value_type.GetSelection())
            if value_type == 'metric':
                agg = stats.aggregate(self.df,
                                      subplots=self.subplots,
                                      rows=self.rows,
                                      cols=self.cols,
                                      yerr=self.yerr,
                                      values=self.values)
            elif value_type == 'accuracy':
                correct = list(
                    self.frame.panel_corr.check_correct.GetCheckedStrings())
                incorrect = list(
                    self.frame.panel_corr.check_incorrect.GetCheckedStrings())
                agg = stats.accuracy(self.df,
                                     subplots=self.subplots,
                                     rows=self.rows,
                                     cols=self.cols,
                                     yerr=self.yerr,
                                     values=self.values,
                                     correct=correct,
                                     incorrect=incorrect)
                #import pdb; pdb.set_trace()
            self.redraw(agg)
            self.frame.list_agg.DeleteAllItems()
            for i in range(self.frame.list_agg.GetColumnCount()):
                self.frame.list_agg.DeleteColumn(0)

            aggr = self.frame.list_agg.stack(agg)
            self.frame.list_agg.set_data(aggr)
            self.frame.aggr = aggr