class MainFrame(BaseFrame):
    """Labeling GUI ("DeepLabCut2.0 - Labeling ToolBox").

    Shows the extracted frames of a chosen folder, lets the user place, drag
    and delete body-part labels, and saves them as
    ``CollectedData_<scorer>.h5`` / ``.csv`` inside that folder.

    Methods call each other through ``self`` (and ``super()`` without
    arguments) instead of the module-level name ``MainFrame``, so they keep
    working even if that name is rebound later in the same module.
    """

    def __init__(self, parent, config, imtypes, config3d, sourceCam, jump_unlabeled):
        """Build the window layout and the control widgets.

        No data is loaded here; loading happens in ``browseDir`` (the
        "Load frames" button).

        Args:
            parent: wx parent window.
            config: path of the project config file (read in ``browseDir``).
            imtypes: forwarded to ``BaseFrame``; ``browseDir`` globs each
                entry of ``self.imtypes`` inside the chosen folder.
            config3d, sourceCam: forwarded unchanged to ``ImagePanel``.
            jump_unlabeled: if true, ``browseDir`` asks
                ``auxiliaryfunctions.find_next_unlabeled_folder`` for the
                folder instead of showing a directory dialog.
        """
        super().__init__("DeepLabCut2.0 - Labeling ToolBox", parent, imtypes)
        self.jump_unlabeled = jump_unlabeled
        self.statusbar.SetStatusText(
            "Looking for a folder to start labeling. Click 'Load frames' to begin."
        )
        self.Bind(wx.EVT_CHAR_HOOK, self.OnKeyPressed)

        # Layout: the top part is split vertically into the image panel (left)
        # and the label-choice panel (right); the bottom panel holds the widgets.
        topSplitter = wx.SplitterWindow(self)
        vSplitter = wx.SplitterWindow(topSplitter)
        self.image_panel = ImagePanel(vSplitter, config, config3d, sourceCam, self.gui_size)
        self.choice_panel = ScrollPanel(vSplitter)
        # wx expects integer sash positions; newer wxPython/Python rejects floats.
        vSplitter.SplitVertically(
            self.image_panel, self.choice_panel, sashPosition=int(self.gui_size[0] * 0.8)
        )
        vSplitter.SetSashGravity(1)
        self.widget_panel = WidgetPanel(topSplitter)
        topSplitter.SplitHorizontally(
            vSplitter, self.widget_panel, sashPosition=int(self.gui_size[1] * 0.83)
        )
        topSplitter.SetSashGravity(1)
        sizer = wx.BoxSizer(wx.VERTICAL)
        sizer.Add(topSplitter, 1, wx.EXPAND)
        self.SetSizer(sizer)

        # Buttons of the widget panel, in display order, bound to their handlers.
        panel = self.widget_panel
        widgetsizer = wx.WrapSizer(orient=wx.HORIZONTAL)
        self.load = self._add_control(
            widgetsizer, wx.Button(panel, id=wx.ID_ANY, label="Load frames"),
            wx.EVT_BUTTON, self.browseDir,
        )
        self.prev = self._add_control(
            widgetsizer, wx.Button(panel, id=wx.ID_ANY, label="<<Previous"),
            wx.EVT_BUTTON, self.prevImage, enabled=False,
        )
        self.next = self._add_control(
            widgetsizer, wx.Button(panel, id=wx.ID_ANY, label="Next>>"),
            wx.EVT_BUTTON, self.nextImage, enabled=False,
        )
        self.help = self._add_control(
            widgetsizer, wx.Button(panel, id=wx.ID_ANY, label="Help"),
            wx.EVT_BUTTON, self.helpButton,
        )
        self.zoom = self._add_control(
            widgetsizer, wx.ToggleButton(panel, label="Zoom"),
            wx.EVT_TOGGLEBUTTON, self.zoomButton, enabled=False,
        )
        self.home = self._add_control(
            widgetsizer, wx.Button(panel, id=wx.ID_ANY, label="Home"),
            wx.EVT_BUTTON, self.homeButton, enabled=False,
        )
        self.pan = self._add_control(
            widgetsizer, wx.ToggleButton(panel, id=wx.ID_ANY, label="Pan"),
            wx.EVT_TOGGLEBUTTON, self.panButton, enabled=False,
        )
        self.lock = self._add_control(
            widgetsizer, wx.CheckBox(panel, id=wx.ID_ANY, label="Lock View"),
            wx.EVT_CHECKBOX, self.lockChecked, enabled=False,
        )
        self.save = self._add_control(
            widgetsizer, wx.Button(panel, id=wx.ID_ANY, label="Save"),
            wx.EVT_BUTTON, self.saveDataSet, enabled=False,
        )
        widgetsizer.AddStretchSpacer(15)
        self.quit = self._add_control(
            widgetsizer, wx.Button(panel, id=wx.ID_ANY, label="Quit"),
            wx.EVT_BUTTON, self.quitButton,
        )
        panel.SetSizerAndFit(widgetsizer)
        panel.Layout()

        # State; most of it is (re)populated by browseDir.
        self.currentDirectory = os.getcwd()
        self.index = []
        self.iter = []
        self.file = 0
        self.updatedCoords = []
        self.dataFrame = None  # None until a folder is loaded (checked by OnKeyPressed)
        self.config_file = config
        self.new_labels = False
        self.buttonCounter = []
        self.bodyparts2plot = []
        self.drs = []
        self.num = []
        self.view_locked = False
        # Workaround for MAC - xlim and ylim changed events seem to be triggered too often so need to make sure that the
        # xlim and ylim have actually changed before turning zoom off
        self.prezoom_xlim = []
        self.prezoom_ylim = []

    def _add_control(self, sizer, control, event_type, handler, enabled=True):
        """Add *control* to *sizer* with the panel's standard padding, bind
        *handler* to *event_type*, set its enabled state and return it."""
        sizer.Add(control, 1, wx.ALL, 15)
        control.Bind(event_type, handler)
        control.Enable(enabled)
        return control

    # BUTTONS FUNCTIONS FOR HOTKEYS
    def OnKeyPressed(self, event=None):
        """Dispatch hotkeys: arrows navigate images/labels, Backspace removes
        the nearest label, Ctrl+C copies the previous image's labels.

        Keys that are not handled here (and all keys before a folder is
        loaded) are passed on with ``event.Skip()`` so wx processes them
        normally; an ``EVT_CHAR_HOOK`` handler that does not skip swallows them.
        """
        if self.dataFrame is None:
            event.Skip()
            return
        key = event.GetKeyCode()
        if key == wx.WXK_RIGHT:
            self.nextImage(event=None)
        elif key == wx.WXK_LEFT:
            self.prevImage(event=None)
        elif key == wx.WXK_DOWN:
            self.nextLabel(event=None)
        elif key == wx.WXK_UP:
            self.previousLabel(event=None)
        elif key == wx.WXK_BACK:
            self._remove_closest_label(event)
        elif event.ControlDown() and key == ord("C"):
            self.duplicate_labels()
        else:
            event.Skip()

    def _remove_closest_label(self, event):
        """Ask for confirmation, then delete the label whose marker is closest
        to the key event's position (mapped into data coordinates)."""
        pos_abs = event.GetPosition()
        inv = self.axes.transData.inverted()
        pos_rel = list(inv.transform(pos_abs))
        y1, y2 = self.axes.get_ylim()
        pos_rel[1] = y1 - pos_rel[1] + y2  # Recall y-axis is inverted
        distances = [self.calc_distance(*dp.point.center, *pos_rel) for dp in self.drs]
        # np.nanargmin raises on an empty or all-NaN sequence (no labels present).
        if not distances or np.all(np.isnan(distances)):
            return
        closest_dp = self.drs[int(np.nanargmin(distances))]
        msg = wx.MessageBox(
            "Do you want to remove the label %s ?" % closest_dp.bodyParts,
            "Remove!",
            wx.YES_NO | wx.ICON_WARNING,
        )
        if msg == wx.YES:
            closest_dp.delete_data()
            bp_index = self.bodyparts.index(closest_dp.bodyParts)
            # The point may already have been deleted (its marker stays selectable).
            if bp_index in self.buttonCounter:
                self.buttonCounter.remove(bp_index)

    def _redraw_current_image(self, keep_view):
        """Redraw ``self.img`` for ``self.iter`` together with its labels.

        Removes the current colorbar axes (``figure.axes[1]``) before
        ``drawplot`` draws the figure again, re-registers the zoom callbacks
        on the returned axes and rebuilds ``self.buttonCounter`` via ``plot``.
        """
        img_name = Path(self.index[self.iter]).name
        self.figure.delaxes(self.figure.axes[1])  # Removes the axes corresponding to the colorbar
        (
            self.figure,
            self.axes,
            self.canvas,
            self.toolbar,
        ) = self.image_panel.drawplot(
            self.img,
            img_name,
            self.iter,
            self.index,
            self.bodyparts,
            self.colormap,
            keep_view=keep_view,
        )
        self.axes.callbacks.connect("xlim_changed", self.onZoom)
        self.axes.callbacks.connect("ylim_changed", self.onZoom)
        # plot() appends to buttonCounter, so start from an empty list.
        self.buttonCounter = []
        self.buttonCounter = self.plot(self.img)

    def duplicate_labels(self):
        """Copy all labels of the previous image onto the current image and redraw."""
        if self.iter < 1:
            return
        curr_image = self.relativeimagenames[self.iter]
        prev_image = self.relativeimagenames[self.iter - 1]
        self.dataFrame.loc[curr_image] = self.dataFrame.loc[prev_image].values
        self._redraw_current_image(keep_view=self.view_locked)

    def activateSlider(self, event):
        """
        Activates the slider to increase the markersize
        """
        self.checkSlider = event.GetEventObject()
        if self.checkSlider.GetValue():
            self.activate_slider = True
            self.slider.Enable(True)
            self.updateZoomPan()
        else:
            self.slider.Enable(False)

    def OnSliderScroll(self, event):
        """
        Adjust marker size for plotting the annotations
        """
        self.saveEachImage()
        self.updateZoomPan()
        self.markerSize = self.slider.GetValue()
        self._redraw_current_image(keep_view=True)

    def quitButton(self, event):
        """
        Asks user for its inputs and then quits the GUI
        """
        self.statusbar.SetStatusText("Quitting now!")
        nextFilemsg = wx.MessageBox(
            "Do you want to label another data set?",
            "Repeat?",
            wx.YES_NO | wx.ICON_INFORMATION,
        )
        if nextFilemsg == wx.YES:
            # Reset the per-dataset state and load another folder.
            self.file = 1
            self.buttonCounter = []
            self.updatedCoords = []
            self.dataFrame = None
            self.bodyparts = []
            self.axes.clear()
            self.figure.delaxes(self.figure.axes[1])
            self.choiceBox.Clear(True)
            self.updateZoomPan()
            self.browseDir(event)
            self.save.Enable(True)
        else:
            self.Destroy()
            print(
                "You can now check the labels, using 'check_labels' before proceeding. Then, you can use the function 'create_training_dataset' to create the training dataset."
            )

    def helpButton(self, event):
        """
        Opens Instructions
        """
        self.updateZoomPan()
        wx.MessageBox(
            "1. Select an individual and one of the body parts from the radio buttons to add a label (if necessary change config.yaml first to edit the label names). \n\n"
            "2. Right clicking on the image will add the selected label and the next available label will be selected from the radio button. \n The label will be marked as circle filled with a unique color (and individual ID a unique color on the rim).\n\n"
            "3. To change the marker size, mark the checkbox and move the slider, then uncheck the box. \n\n"
            "4. Hover your mouse over this newly added label to see its name. \n\n"
            "5. Use left click and drag to move the label position. \n\n"
            "6. Once you are happy with the position, right click to add the next available label. You can always reposition the old labels, if required. You can delete a label with the middle button mouse click (or click 'delete' key). \n\n"
            "7. Click Next/Previous to move to the next/previous image (or hot-key arrows left and right).\n User can also re-label a deleted point by going to a previous/next image then returning to the current image. \n NOTE: the user cannot add a label if the label is already present. \n \n"
            "8. You can click Cntrl+C to copy+paste labels from a previous image into the current image. \n\n"
            "9. When finished labeling all the images, click 'Save' to save all the labels as a .h5 file. \n\n"
            "10. Click OK to continue using the labeling GUI. For more tips and hotkeys: see docs!!",
            "User instructions",
            wx.OK | wx.ICON_INFORMATION,
        )
        self.statusbar.SetStatusText("Help")

    def onButtonRelease(self, event):
        """Turn panning off after a mouse release while the Pan toggle is on."""
        if self.pan.GetValue():
            self.updateZoomPan()
            self.statusbar.SetStatusText("Pan Off")

    def onClick(self, event):
        """
        This function adds labels and auto advances to the next label.

        Only right clicks (button 3) inside an axes add a label; a click
        outside any axes has no data coordinates and is ignored.
        """
        x1 = event.xdata
        y1 = event.ydata
        if event.button != 3 or x1 is None or y1 is None:
            return
        selection = self.rdb.GetSelection()
        if selection in self.buttonCounter:
            wx.MessageBox(
                "%s is already annotated. \n Select another body part to annotate."
                % (str(self.bodyparts[selection])),
                "Error!",
                wx.OK | wx.ICON_ERROR,
            )
            return
        bodypart = self.bodyparts[selection]
        color = self.colormap(self.norm(self.colorIndex[selection]))
        circle = [
            patches.Circle((x1, y1), radius=self.markerSize, fc=color, alpha=self.alpha)
        ]
        self.num.append(circle)
        self.axes.add_patch(circle[0])
        self.dr = auxfun_drag.DraggablePoint(circle[0], bodypart)
        self.dr.connect()
        self.buttonCounter.append(selection)
        self.dr.coords = [[x1, y1, bodypart, selection]]
        self.drs.append(self.dr)
        self.updatedCoords.append(self.dr.coords)
        # Auto-advance to the next body part.
        if selection < len(self.bodyparts) - 1:
            self.rdb.SetSelection(selection + 1)
        self.figure.canvas.draw()

    def nextLabel(self, event):
        """
        This function is to create a hotkey to skip down on the radio button panel.
        """
        if self.rdb.GetSelection() < len(self.bodyparts) - 1:
            self.rdb.SetSelection(self.rdb.GetSelection() + 1)

    def previousLabel(self, event):
        """
        This function is to create a hotkey to skip up on the radio button panel.
        """
        if self.rdb.GetSelection() > 0:
            self.rdb.SetSelection(self.rdb.GetSelection() - 1)

    @staticmethod
    def _relative_image_name(path):
        """Return *path* as a tuple of path components starting at its last
        ``labeled-data`` component, e.g. ``("labeled-data", video, "img.png")``.

        Falls back to the historical ``split("labeled")`` behaviour when no
        ``labeled-data`` component is present.
        """
        parts = Path(path).parts
        for start in range(len(parts) - 1, -1, -1):
            if parts[start] == "labeled-data":
                return tuple(parts[start:])
        return tuple(("labeled" + path.split("labeled")[1]).split(os.path.sep))

    def _empty_label_frame(self, imagenames, bodyparts):
        """Return an all-NaN x/y label frame for *bodyparts*, indexed by the
        *imagenames* tuples, with (scorer, bodyparts, coords) columns.

        Returns None when *bodyparts* is empty (pd.concat drops None).
        """
        if not bodyparts:
            return None
        index = pd.MultiIndex.from_tuples(imagenames)
        frames = [
            pd.DataFrame(
                np.full((len(imagenames), 2), np.nan),
                columns=pd.MultiIndex.from_product(
                    [[self.scorer], [bodypart], ["x", "y"]],
                    names=["scorer", "bodyparts", "coords"],
                ),
                index=index,
            )
            for bodypart in bodyparts
        ]
        return pd.concat(frames, axis=1)

    def browseDir(self, event):
        """
        Show the DirDialog and ask the user to change the directory where machine labels are stored

        Loads the images of the chosen folder and its existing
        ``CollectedData_<scorer>.h5`` (if readable), adds rows for new images
        and columns for new body parts, then draws the first image that has
        no labels at all (or the first image).
        """
        if self.jump_unlabeled:
            self.dir = str(auxiliaryfunctions.find_next_unlabeled_folder(self.config_file))
        else:
            self.statusbar.SetStatusText("Looking for a folder to start labeling...")
            cwd = os.path.join(os.getcwd(), "labeled-data")
            dlg = wx.DirDialog(
                self,
                "Choose the directory where your extracted frames are saved:",
                cwd,
                style=wx.DD_DEFAULT_STYLE,
            )
            if dlg.ShowModal() != wx.ID_OK:
                dlg.Destroy()
                self.Close(True)
                return
            self.dir = dlg.GetPath()
            dlg.Destroy()

        # Reading config file and its variables
        self.cfg = auxiliaryfunctions.read_config(self.config_file)
        self.scorer = self.cfg["scorer"]
        self.bodyparts = self.cfg["bodyparts"]
        self.videos = self.cfg["video_sets"].keys()
        self.markerSize = self.cfg["dotsize"]
        self.alpha = self.cfg["alphavalue"]
        self.colormap = plt.get_cmap(self.cfg["colormap"])
        self.colormap = self.colormap.reversed()
        self.project_path = self.cfg["project_path"]

        # checks for unique bodyparts; stop here instead of continuing after Close().
        if len(self.bodyparts) != len(set(self.bodyparts)):
            print(
                "Error - bodyparts must have unique labels! Please choose unique bodyparts in config.yaml file and try again. Quitting for now!"
            )
            self.Close(True)
            return

        imlist = [
            fn
            for imtype in self.imtypes
            for fn in glob.glob(os.path.join(self.dir, imtype))
            if "labeled.png" not in fn
        ]
        if not imlist:
            print("No images found!!")
            # Nothing to label: let the user pick another folder.
            self.load.Enable(True)
            return

        self.load.Enable(False)
        self.next.Enable(True)
        self.save.Enable(True)
        # Enabling the zoom, pan and home buttons
        self.zoom.Enable(True)
        self.home.Enable(True)
        self.pan.Enable(True)
        self.lock.Enable(True)

        self.index = np.sort(imlist)
        self.statusbar.SetStatusText(
            "Working on folder: {}".format(os.path.split(str(self.dir))[-1])
        )
        self.relativeimagenames = [self._relative_image_name(n) for n in self.index]

        # Reading the existing dataset, if already present
        try:
            self.dataFrame = pd.read_hdf(
                os.path.join(self.dir, "CollectedData_" + self.scorer + ".h5")
            )
            conversioncode.guarantee_multiindex_rows(self.dataFrame)
            self.dataFrame.sort_index(inplace=True)
            self.prev.Enable(True)
        except Exception:
            # No (readable) annotation file: start from an all-NaN frame.
            self.dataFrame = self._empty_label_frame(self.relativeimagenames, self.bodyparts)

        # Checking for new frames and adding them to the existing dataframe
        self.newimages = sorted(set(self.relativeimagenames) - set(self.dataFrame.index))
        if self.newimages:
            print("Found new frames..")
            self.df = self._empty_label_frame(self.newimages, self.bodyparts)
            self.dataFrame = pd.concat([self.dataFrame, self.df], axis=0)
            # Sort it by the index values
            self.dataFrame.sort_index(inplace=True)

        # Start at the first row without any label. This runs after the merge
        # and sort above so that self.iter matches the final row order.
        empty_rows = np.flatnonzero(np.isnan(self.dataFrame.to_numpy()).all(axis=1))
        self.iter = int(empty_rows[0]) if empty_rows.size else 0

        # Reading the image name
        self.img = os.path.join(*self.dataFrame.index[self.iter])
        img_name = Path(self.img).name
        self.norm, self.colorIndex = getColorIndices(self.img, self.bodyparts)

        # Body parts listed in the config but missing from the dataframe.
        existing_bodyparts = set(self.dataFrame.columns.get_level_values(1))
        self.new_bodyparts = [x for x in self.bodyparts if x not in existing_bodyparts]
        if self.new_bodyparts:
            dlg = wx.MessageDialog(
                None,
                "New label found in the config file. Do you want to see all the other labels?",
                "New label found",
                wx.YES_NO | wx.ICON_WARNING,
            )
            result = dlg.ShowModal()
            dlg.Destroy()
            if result == wx.ID_NO:
                self.bodyparts = self.new_bodyparts
                self.norm, self.colorIndex = getColorIndices(self.img, self.bodyparts)
            self.dataFrame = pd.concat(
                [
                    self.dataFrame,
                    self._empty_label_frame(self.relativeimagenames, self.new_bodyparts),
                ],
                axis=1,
            )

        (
            self.figure,
            self.axes,
            self.canvas,
            self.toolbar,
        ) = self.image_panel.drawplot(
            self.img, img_name, self.iter, self.index, self.bodyparts, self.colormap
        )
        self.axes.callbacks.connect("xlim_changed", self.onZoom)
        self.axes.callbacks.connect("ylim_changed", self.onZoom)
        (
            self.choiceBox,
            self.rdb,
            self.slider,
            self.checkBox,
        ) = self.choice_panel.addRadioButtons(self.bodyparts, self.file, self.markerSize)
        # plot() appends to buttonCounter, so start from an empty list.
        self.buttonCounter = []
        self.buttonCounter = self.plot(self.img)
        self.cidClick = self.canvas.mpl_connect("button_press_event", self.onClick)
        self.canvas.mpl_connect("button_release_event", self.onButtonRelease)
        self.checkBox.Bind(wx.EVT_CHECKBOX, self.activateSlider)
        self.slider.Bind(wx.EVT_SLIDER, self.OnSliderScroll)

    def nextImage(self, event):
        """
        Moves to next image
        """
        # Checks for the last image and disables the Next button
        if len(self.index) - self.iter == 1:
            self.next.Enable(False)
            return
        self.prev.Enable(True)

        # Checks if zoom/pan button is ON
        self.updateZoomPan()
        self.statusbar.SetStatusText(
            "Working on folder: {}".format(os.path.split(str(self.dir))[-1])
        )
        self.rdb.SetSelection(0)
        self.file = 1
        # Refreshing the button counter
        self.buttonCounter = []
        self.saveEachImage()

        self.iter += 1
        self.img = self.index[self.iter]
        self._redraw_current_image(keep_view=self.view_locked)
        self.cidClick = self.canvas.mpl_connect("button_press_event", self.onClick)
        self.canvas.mpl_connect("button_release_event", self.onButtonRelease)

    def prevImage(self, event):
        """
        Checks the previous Image and enables user to move the annotations.
        """
        # Checks for the first image and disables the Previous button
        if self.iter == 0:
            self.prev.Enable(False)
            return
        self.next.Enable(True)

        # Checks if zoom/pan button is ON
        self.updateZoomPan()
        self.statusbar.SetStatusText(
            "Working on folder: {}".format(os.path.split(str(self.dir))[-1])
        )
        self.saveEachImage()

        self.buttonCounter = []
        self.iter -= 1
        self.rdb.SetSelection(0)
        self.img = self.index[self.iter]
        self._redraw_current_image(keep_view=self.view_locked)
        self.cidClick = self.canvas.mpl_connect("button_press_event", self.onClick)
        self.canvas.mpl_connect("button_release_event", self.onButtonRelease)

    def getLabels(self, img_index):
        """
        Returns a list of x and y labels of the corresponding image index

        Each entry is ``[[x, y, bodypart, bodypart_index]]`` read from row
        position *img_index* of the dataframe; the list is also stored in
        ``self.previous_image_points``.
        """
        self.previous_image_points = [
            [
                [
                    self.dataFrame[self.scorer][bp]["x"].values[img_index],
                    self.dataFrame[self.scorer][bp]["y"].values[img_index],
                    bp,
                    bpindex,
                ]
            ]
            for bpindex, bp in enumerate(self.bodyparts)
        ]
        return self.previous_image_points

    def plot(self, img):
        """
        Plots and call auxfun_drag class for moving and removing points.

        Draws one draggable circle per body part for the current image and
        appends the indices of labelled (non-NaN x) body parts to
        ``self.buttonCounter``, which is returned. *img* is unused.
        """
        self.drs = []
        self.updatedCoords = []
        labels = self.getLabels(self.iter)
        for bpindex, bp in enumerate(self.bodyparts):
            color = self.colormap(self.norm(self.colorIndex[bpindex]))
            x, y = labels[bpindex][0][0], labels[bpindex][0][1]
            self.points = [x, y]
            circle = [
                patches.Circle((x, y), radius=self.markerSize, fc=color, alpha=self.alpha)
            ]
            self.axes.add_patch(circle[0])
            self.dr = auxfun_drag.DraggablePoint(circle[0], bp)
            self.dr.connect()
            self.dr.coords = labels[bpindex]
            self.drs.append(self.dr)
            self.updatedCoords.append(self.dr.coords)
            if not np.isnan(x):
                self.buttonCounter.append(bpindex)
        self.figure.canvas.draw()
        return self.buttonCounter

    def saveEachImage(self):
        """
        Saves data for each image

        Writes the latest coordinates of every tracked point into the row of
        the current image. Each entry of ``self.updatedCoords`` starts with
        ``[x, y, bodypart, index]``; its last element holds the most recent
        position. A single ``.loc[row, col]`` assignment is used because
        chained ``.loc[row][col] = v`` may write to a temporary copy.
        """
        row = self.relativeimagenames[self.iter]
        for coords in self.updatedCoords:
            bodypart = coords[0][-2]
            x, y = coords[-1][0], coords[-1][1]
            self.dataFrame.loc[row, (self.scorer, bodypart, "x")] = x
            self.dataFrame.loc[row, (self.scorer, bodypart, "y")] = y

    def saveDataSet(self, event):
        """
        Saves the final dataframe

        Columns are reordered to the config's body-part order, then written
        to ``CollectedData_<scorer>.csv`` and ``.h5`` in the current folder.
        """
        self.saveEachImage()
        self.updateZoomPan()
        # Windows compatible
        self.dataFrame.sort_index(inplace=True)
        self.dataFrame = self.dataFrame.reindex(
            self.cfg["bodyparts"],
            axis=1,
            level=self.dataFrame.columns.names.index("bodyparts"),
        )
        base = os.path.join(self.dir, "CollectedData_" + self.scorer)
        self.dataFrame.to_csv(base + ".csv")
        self.dataFrame.to_hdf(base + ".h5", key="df_with_missing")
        # Report success only after both files were written.
        self.statusbar.SetStatusText("File saved")

    def onChecked(self, event):
        """Enable the marker-size slider while the checkbox is ticked."""
        self.cb = event.GetEventObject()
        if self.cb.GetValue():
            self.slider.Enable(True)
            self.cidClick = self.canvas.mpl_connect("button_press_event", self.onClick)
            self.canvas.mpl_connect("button_release_event", self.onButtonRelease)
        else:
            self.slider.Enable(False)
class MainFrame(BaseFrame): """Contains the main GUI and button boxes""" def __init__(self, parent, config, video, shuffle, Dataframe, savelabeled, multianimal): super(MainFrame, self).__init__("DeepLabCut2.0 - Manual Outlier Frame Extraction", parent) ################################################################################################################################################### # Spliting the frame into top and bottom panels. Bottom panels contains the widgets. The top panel is for showing images and plotting! # topSplitter = wx.SplitterWindow(self) # # self.image_panel = ImagePanel(topSplitter, config,video,shuffle,Dataframe,self.gui_size) # self.widget_panel = WidgetPanel(topSplitter) # # topSplitter.SplitHorizontally(self.image_panel, self.widget_panel,sashPosition=self.gui_size[1]*0.83)#0.9 # topSplitter.SetSashGravity(1) # sizer = wx.BoxSizer(wx.VERTICAL) # sizer.Add(topSplitter, 1, wx.EXPAND) # self.SetSizer(sizer) # Spliting the frame into top and bottom panels. Bottom panels contains the widgets. The top panel is for showing images and plotting! topSplitter = wx.SplitterWindow(self) vSplitter = wx.SplitterWindow(topSplitter) self.image_panel = ImagePanel(vSplitter, self.gui_size) self.choice_panel = ScrollPanel(vSplitter) vSplitter.SplitVertically(self.image_panel, self.choice_panel, sashPosition=self.gui_size[0] * 0.8) vSplitter.SetSashGravity(1) self.widget_panel = WidgetPanel(topSplitter) topSplitter.SplitHorizontally(vSplitter, self.widget_panel, sashPosition=self.gui_size[1] * 0.83) # 0.9 topSplitter.SetSashGravity(1) sizer = wx.BoxSizer(wx.VERTICAL) sizer.Add(topSplitter, 1, wx.EXPAND) self.SetSizer(sizer) ################################################################################################################################################### # Add Buttons to the WidgetPanel and bind them to their respective functions. 
widgetsizer = wx.WrapSizer(orient=wx.HORIZONTAL) self.load_button_sizer = wx.BoxSizer(wx.VERTICAL) self.help_button_sizer = wx.BoxSizer(wx.VERTICAL) self.help = wx.Button(self.widget_panel, id=wx.ID_ANY, label="Help") self.help_button_sizer.Add(self.help, 1, wx.ALL, 15) # widgetsizer.Add(self.help , 1, wx.ALL, 15) self.help.Bind(wx.EVT_BUTTON, self.helpButton) widgetsizer.Add(self.help_button_sizer, 1, wx.ALL, 0) self.grab = wx.Button(self.widget_panel, id=wx.ID_ANY, label="Grab Frames") widgetsizer.Add(self.grab, 1, wx.ALL, 15) self.grab.Bind(wx.EVT_BUTTON, self.grabFrame) self.grab.Enable(True) widgetsizer.AddStretchSpacer(5) self.slider = wx.Slider( self.widget_panel, id=wx.ID_ANY, value=0, minValue=0, maxValue=1, size=(200, -1), style=wx.SL_HORIZONTAL | wx.SL_AUTOTICKS | wx.SL_LABELS, ) widgetsizer.Add(self.slider, 1, wx.ALL, 5) self.slider.Bind(wx.EVT_SLIDER, self.OnSliderScroll) widgetsizer.AddStretchSpacer(5) self.start_frames_sizer = wx.BoxSizer(wx.VERTICAL) self.end_frames_sizer = wx.BoxSizer(wx.VERTICAL) self.start_frames_sizer.AddSpacer(15) # self.startFrame = wx.SpinCtrl(self.widget_panel, value='0', size=(100, -1), min=0, max=120) self.startFrame = wx.SpinCtrl(self.widget_panel, value="0", size=(100, -1)) # ,style=wx.SP_VERTICAL) self.startFrame.Enable(False) self.start_frames_sizer.Add(self.startFrame, 1, wx.EXPAND | wx.ALIGN_LEFT, 15) start_text = wx.StaticText(self.widget_panel, label="Start Frame Index") self.start_frames_sizer.Add(start_text, 1, wx.EXPAND | wx.ALIGN_LEFT, 15) self.checkBox = wx.CheckBox(self.widget_panel, id=wx.ID_ANY, label="Range of frames") self.checkBox.Bind(wx.EVT_CHECKBOX, self.activate_frame_range) self.start_frames_sizer.Add(self.checkBox, 1, wx.EXPAND | wx.ALIGN_LEFT, 15) # self.end_frames_sizer.AddSpacer(15) self.endFrame = wx.SpinCtrl(self.widget_panel, value="1", size=(160, -1)) # , min=1, max=120) self.endFrame.Enable(False) self.end_frames_sizer.Add(self.endFrame, 1, wx.EXPAND | wx.ALIGN_LEFT, 15) end_text = 
wx.StaticText(self.widget_panel, label="Number of Frames") self.end_frames_sizer.Add(end_text, 1, wx.EXPAND | wx.ALIGN_LEFT, 15) self.updateFrame = wx.Button(self.widget_panel, id=wx.ID_ANY, label="Update") self.end_frames_sizer.Add(self.updateFrame, 1, wx.EXPAND | wx.ALIGN_LEFT, 15) self.updateFrame.Bind(wx.EVT_BUTTON, self.updateSlider) self.updateFrame.Enable(False) widgetsizer.Add(self.start_frames_sizer, 1, wx.ALL, 0) widgetsizer.AddStretchSpacer(5) widgetsizer.Add(self.end_frames_sizer, 1, wx.ALL, 0) widgetsizer.AddStretchSpacer(15) self.quit = wx.Button(self.widget_panel, id=wx.ID_ANY, label="Quit") widgetsizer.Add(self.quit, 1, wx.ALL, 15) self.quit.Bind(wx.EVT_BUTTON, self.quitButton) self.quit.Enable(True) self.widget_panel.SetSizer(widgetsizer) self.widget_panel.SetSizerAndFit(widgetsizer) # Variables initialization self.numberFrames = 0 self.currFrame = 0 self.figure = Figure() self.axes = self.figure.add_subplot(111) self.drs = [] self.extract_range_frame = False self.firstFrame = 0 self.Colorscheme = [] # Read confing file self.cfg = auxiliaryfunctions.read_config(config) self.Task = self.cfg["Task"] self.start = self.cfg["start"] self.stop = self.cfg["stop"] self.date = self.cfg["date"] self.trainFraction = self.cfg["TrainingFraction"] self.trainFraction = self.trainFraction[0] self.videos = self.cfg["video_sets"].keys() self.bodyparts = self.cfg["bodyparts"] self.colormap = plt.get_cmap(self.cfg["colormap"]) self.colormap = self.colormap.reversed() self.markerSize = self.cfg["dotsize"] self.alpha = self.cfg["alphavalue"] self.iterationindex = self.cfg["iteration"] self.cropping = self.cfg["cropping"] self.video_names = [Path(i).stem for i in self.videos] self.config_path = Path(config) self.video_source = Path(video).resolve() self.shuffle = shuffle self.Dataframe = Dataframe self.savelabeled = savelabeled self.multianimal = multianimal if self.multianimal: from deeplabcut.utils import auxfun_multianimal ( self.individual_names, 
self.uniquebodyparts, self.multianimalbodyparts, ) = auxfun_multianimal.extractindividualsandbodyparts(self.cfg) self.choiceBox, self.visualization_rdb = self.choice_panel.addRadioButtons( ) self.Colorscheme = visualization.get_cmap( len(self.individual_names), self.cfg["colormap"]) self.visualization_rdb.Bind(wx.EVT_RADIOBOX, self.clear_plot) # Read the video file self.vid = VideoWriter(str(self.video_source)) if self.cropping: self.vid.set_bbox(self.cfg["x1"], self.cfg["x2"], self.cfg["y1"], self.cfg["y2"]) self.filename = Path(self.video_source).name self.numberFrames = len(self.vid) self.strwidth = int(np.ceil(np.log10(self.numberFrames))) # Set the values of slider and range of frames self.startFrame.SetMax(self.numberFrames - 1) self.slider.SetMax(self.numberFrames - 1) self.endFrame.SetMax(self.numberFrames - 1) self.startFrame.Bind(wx.EVT_SPINCTRL, self.updateSlider) # wx.EVT_SPIN # Set the status bar self.statusbar.SetStatusText("Working on video: {}".format( self.filename)) # Adding the video file to the config file. 
if self.vid.name not in self.video_names: add.add_new_videos(self.config_path, [self.video_source]) self.update() self.plot_labels() self.widget_panel.Layout() def quitButton(self, event): """ Quits the GUI """ self.statusbar.SetStatusText("") dlg = wx.MessageDialog(None, "Are you sure?", "Quit!", wx.YES_NO | wx.ICON_WARNING) result = dlg.ShowModal() if result == wx.ID_YES: print("Quitting for now!") self.Destroy() def updateSlider(self, event): self.slider.SetValue(self.startFrame.GetValue()) self.startFrame.SetValue(self.slider.GetValue()) self.axes.clear() self.figure.delaxes(self.figure.axes[1]) self.grab.Bind(wx.EVT_BUTTON, self.grabFrame) self.currFrame = self.slider.GetValue() self.update() self.plot_labels() def activate_frame_range(self, event): """ Activates the frame range boxes """ self.checkSlider = event.GetEventObject() if self.checkSlider.GetValue(): self.extract_range_frame = True self.startFrame.Enable(True) self.startFrame.SetValue(self.slider.GetValue()) self.endFrame.Enable(True) self.updateFrame.Enable(True) self.grab.Enable(False) else: self.extract_range_frame = False self.startFrame.Enable(False) self.endFrame.Enable(False) self.updateFrame.Enable(False) self.grab.Enable(True) def line_select_callback(self, eclick, erelease): "eclick and erelease are the press and release events" self.new_x1, self.new_y1 = eclick.xdata, eclick.ydata self.new_x2, self.new_y2 = erelease.xdata, erelease.ydata def OnSliderScroll(self, event): """ Slider to scroll through the video """ self.axes.clear() self.figure.delaxes(self.figure.axes[1]) self.grab.Bind(wx.EVT_BUTTON, self.grabFrame) self.currFrame = self.slider.GetValue() self.startFrame.SetValue(self.currFrame) self.update() self.plot_labels() def update(self): """ Updates the image with the current slider index """ self.grab.Enable(True) self.grab.Bind(wx.EVT_BUTTON, self.grabFrame) self.figure, self.axes, self.canvas = self.image_panel.getfigure() self.vid.set_to_frame(self.currFrame) frame = 
self.vid.read_frame(crop=self.cropping) if frame is not None: frame = img_as_ubyte(frame) self.ax = self.axes.imshow(frame, cmap=self.colormap) self.axes.set_title( str( str(self.currFrame) + "/" + str(self.numberFrames - 1) + " " + self.filename)) self.figure.canvas.draw() else: print("Invalid frame") def chooseFrame(self): frame = img_as_ubyte(self.vid.read_frame(crop=self.cropping)) fname = Path(self.filename) output_path = self.config_path.parents[0] / "labeled-data" / fname.stem self.machinefile = os.path.join( str(output_path), "machinelabels-iter" + str(self.iterationindex) + ".h5") name = str(fname.stem) DF = self.Dataframe.iloc[[self.currFrame]] DF.index = [ os.path.join("labeled-data", name, "img" + str(index).zfill(self.strwidth) + ".png") for index in DF.index ] img_name = (str(output_path) + "/img" + str(self.currFrame).zfill( int(np.ceil(np.log10(self.numberFrames)))) + ".png") labeled_img_name = (str(output_path) + "/img" + str( self.currFrame).zfill(int(np.ceil(np.log10(self.numberFrames)))) + "labeled.png") # Check for it output path and a machine label file exist if output_path.exists() and Path(self.machinefile).is_file(): io.imsave(img_name, frame) if self.savelabeled: self.figure.savefig(labeled_img_name, bbox_inches="tight") Data = pd.read_hdf(self.machinefile) DataCombined = pd.concat([Data, DF]) DataCombined = DataCombined[~DataCombined.index.duplicated( keep="first")] DataCombined.to_hdf(self.machinefile, key="df_with_missing", mode="w") DataCombined.to_csv( os.path.join(str(output_path), "machinelabels.csv")) # If machine label file does not exist then create one elif output_path.exists() and not (Path(self.machinefile).is_file()): if self.savelabeled: self.figure.savefig(labeled_img_name, bbox_inches="tight") io.imsave(img_name, frame) # cv2.imwrite(img_name, frame) DF.to_hdf(self.machinefile, key="df_with_missing", mode="w") DF.to_csv(os.path.join(str(output_path), "machinelabels.csv")) else: print( "%s path not found. 
Please make sure that the video was added to the config file using the function 'deeplabcut.add_new_videos'.Quitting for now!" % output_path) self.Destroy() def grabFrame(self, event): """ Extracts the frame and saves in the current directory """ if self.extract_range_frame: num_frames_extract = self.endFrame.GetValue() for i in range(self.currFrame, self.currFrame + num_frames_extract): self.currFrame = i self.vid.set_to_frame(self.currFrame) self.chooseFrame() else: self.vid.set_to_frame(self.currFrame) self.chooseFrame() def clear_plot(self, event): self.figure.delaxes(self.figure.axes[1]) [p.remove() for p in reversed(self.axes.patches)] self.plot_labels() def plot_labels(self): """ Plots the labels of the analyzed video """ self.vid.set_to_frame(self.currFrame) frame = self.vid.read_frame() if frame is not None: divider = make_axes_locatable(self.axes) cax = divider.append_axes("right", size="5%", pad=0.05) if self.multianimal: # take into account of all the bodyparts for the colorscheme. Sort the bodyparts to have same order as in the config file self.all_bodyparts = np.array(self.multianimalbodyparts + self.uniquebodyparts) _, return_idx = np.unique(self.all_bodyparts, return_index=True) self.all_bodyparts = list( self.all_bodyparts[np.sort(return_idx)]) if (self.visualization_rdb.GetSelection() == 0 ): # i.e. for color scheme for individuals self.Colorscheme = visualization.get_cmap( len(self.individual_names), self.cfg["colormap"]) self.norm, self.colorIndex = self.image_panel.getColorIndices( frame, self.individual_names) cbar = self.figure.colorbar(self.ax, cax=cax, spacing="proportional", ticks=self.colorIndex) cbar.set_ticklabels(self.individual_names) else: # i.e. 
for color scheme for all bodyparts self.Colorscheme = visualization.get_cmap( len(self.all_bodyparts), self.cfg["colormap"]) self.norm, self.colorIndex = self.image_panel.getColorIndices( frame, self.all_bodyparts) cbar = self.figure.colorbar(self.ax, cax=cax, spacing="proportional", ticks=self.colorIndex) cbar.set_ticklabels(self.all_bodyparts) for ci, ind in enumerate(self.individual_names): col_idx = ( 0 ) # variable for iterating through the colorscheme for all bodyparts image_points = [] if ind == "single": if self.visualization_rdb.GetSelection() == 0: for c, bp in enumerate(self.uniquebodyparts): pts = self.Dataframe.xs( (ind, bp), level=("individuals", "bodyparts"), axis=1, ).values self.circle = patches.Circle( pts[self.currFrame, :2], radius=self.markerSize, fc=self.Colorscheme(ci), alpha=self.alpha, ) self.axes.add_patch(self.circle) else: for c, bp in enumerate(self.uniquebodyparts): pts = self.Dataframe.xs( (ind, bp), level=("individuals", "bodyparts"), axis=1, ).values self.circle = patches.Circle( pts[self.currFrame, :2], radius=self.markerSize, fc=self.Colorscheme(col_idx), alpha=self.alpha, ) self.axes.add_patch(self.circle) col_idx = col_idx + 1 else: if self.visualization_rdb.GetSelection() == 0: for c, bp in enumerate(self.multianimalbodyparts): pts = self.Dataframe.xs( (ind, bp), level=("individuals", "bodyparts"), axis=1, ).values self.circle = patches.Circle( pts[self.currFrame, :2], radius=self.markerSize, fc=self.Colorscheme(ci), alpha=self.alpha, ) self.axes.add_patch(self.circle) else: for c, bp in enumerate(self.multianimalbodyparts): pts = self.Dataframe.xs( (ind, bp), level=("individuals", "bodyparts"), axis=1, ).values self.circle = patches.Circle( pts[self.currFrame, :2], radius=self.markerSize, fc=self.Colorscheme(col_idx), alpha=self.alpha, ) self.axes.add_patch(self.circle) col_idx = col_idx + 1 self.figure.canvas.draw() else: self.norm, self.colorIndex = self.image_panel.getColorIndices( frame, self.bodyparts) cbar = 
self.figure.colorbar(self.ax, cax=cax, spacing="proportional", ticks=self.colorIndex) cbar.set_ticklabels(self.bodyparts) for bpindex, bp in enumerate(self.bodyparts): color = self.colormap(self.norm(self.colorIndex[bpindex])) self.points = [ self.Dataframe.xs((bp, "x"), level=(-2, -1), axis=1).values[self.currFrame], self.Dataframe.xs((bp, "y"), level=(-2, -1), axis=1).values[self.currFrame], 1.0, ] circle = [ patches.Circle( (self.points[0], self.points[1]), radius=self.markerSize, fc=color, alpha=self.alpha, ) ] self.axes.add_patch(circle[0]) self.figure.canvas.draw() else: print("Invalid frame") def helpButton(self, event): """ Opens Instructions """ wx.MessageBox( "1. Use the slider to select a frame in the entire video. \n\n2. Click Grab Frames button to save the specific frame.\ \n\n3. In the events where you need to extract a range of frames, then use the checkbox 'Range of frames' to select the starting frame index and the number of frames to extract.\ \n Click the update button to see the frame. Click Grab Frames to select the range of frames. \n\n Click OK to continue", "Instructions to use!", wx.OK | wx.ICON_INFORMATION, )
class MainFrame(BaseFrame):
    """Refinement ToolBox: review and correct machine-predicted labels.

    The user loads a ``machinelabels-iter<N>.h5`` file and steps through its
    images. Predicted points can be dragged or deleted, then saved. Saving
    overwrites the machine-label file and writes (or merges into) a
    ``CollectedData_<scorer>`` file in the same folder.
    """

    def __init__(self, parent, config, jump_unlabeled):
        """Build the window layout and read display settings from *config*.

        Args:
            parent: parent wx window (may be None).
            config: path to the project config, read with
                ``auxiliaryfunctions.read_config``.
            jump_unlabeled: if truthy, the file dialog opens in the folder
                returned by ``auxiliaryfunctions.find_next_unlabeled_folder``.
        """
        super(MainFrame, self).__init__("DeepLabCut2.0 - Refinement ToolBox", parent)
        self.Bind(wx.EVT_CHAR_HOOK, self.OnKeyPressed)
        self.jump_unlabeled = jump_unlabeled

        # Split the frame into top and bottom panels. The bottom panel holds
        # the widgets; the top panel shows the image (left) and the
        # choice panel (right).
        topSplitter = wx.SplitterWindow(self)
        vSplitter = wx.SplitterWindow(topSplitter)

        self.image_panel = ImagePanel(vSplitter, config, self.gui_size)
        self.choice_panel = ScrollPanel(vSplitter)
        # wx expects integer sash positions; recent wxPython/Python
        # combinations raise TypeError when given a float.
        vSplitter.SplitVertically(
            self.image_panel,
            self.choice_panel,
            sashPosition=int(self.gui_size[0] * 0.8),
        )
        vSplitter.SetSashGravity(1)
        self.widget_panel = WidgetPanel(topSplitter)
        topSplitter.SplitHorizontally(
            vSplitter, self.widget_panel, sashPosition=int(self.gui_size[1] * 0.83)
        )
        topSplitter.SetSashGravity(1)
        sizer = wx.BoxSizer(wx.VERTICAL)
        sizer.Add(topSplitter, 1, wx.EXPAND)
        self.SetSizer(sizer)

        # Add buttons to the widget panel and bind them to their handlers.
        panel = self.widget_panel
        widgetsizer = wx.WrapSizer(orient=wx.HORIZONTAL)
        self.load = self._add_control(
            widgetsizer,
            wx.Button(panel, id=wx.ID_ANY, label="Load labels"),
            wx.EVT_BUTTON,
            self.browseDir,
        )
        self.prev = self._add_control(
            widgetsizer,
            wx.Button(panel, id=wx.ID_ANY, label="<<Previous"),
            wx.EVT_BUTTON,
            self.prevImage,
            enabled=False,
        )
        self.next = self._add_control(
            widgetsizer,
            wx.Button(panel, id=wx.ID_ANY, label="Next>>"),
            wx.EVT_BUTTON,
            self.nextImage,
            enabled=False,
        )
        self.help = self._add_control(
            widgetsizer,
            wx.Button(panel, id=wx.ID_ANY, label="Help"),
            wx.EVT_BUTTON,
            self.helpButton,
        )
        self.zoom = self._add_control(
            widgetsizer,
            wx.ToggleButton(panel, label="Zoom"),
            wx.EVT_TOGGLEBUTTON,
            self.zoomButton,
            enabled=False,
        )
        self.home = self._add_control(
            widgetsizer,
            wx.Button(panel, id=wx.ID_ANY, label="Home"),
            wx.EVT_BUTTON,
            self.homeButton,
            enabled=False,
        )
        self.pan = self._add_control(
            widgetsizer,
            wx.ToggleButton(panel, id=wx.ID_ANY, label="Pan"),
            wx.EVT_TOGGLEBUTTON,
            self.panButton,
            enabled=False,
        )
        self.lock = self._add_control(
            widgetsizer,
            wx.CheckBox(panel, id=wx.ID_ANY, label="Lock View"),
            wx.EVT_CHECKBOX,
            self.lockChecked,
            enabled=False,
        )
        self.save = self._add_control(
            widgetsizer,
            wx.Button(panel, id=wx.ID_ANY, label="Save"),
            wx.EVT_BUTTON,
            self.saveDataSet,
            enabled=False,
        )
        widgetsizer.AddStretchSpacer(15)
        self.quit = self._add_control(
            widgetsizer,
            wx.Button(panel, id=wx.ID_ANY, label="Quit"),
            wx.EVT_BUTTON,
            self.quitButton,
        )
        panel.SetSizerAndFit(widgetsizer)
        panel.Layout()

        # Variable initialization
        self.currentDirectory = os.getcwd()
        self.index = []
        self.iter = []  # becomes an int image position once a file is loaded
        self.file = 0
        self.updatedCoords = []
        self.dataFrame = None
        self.drs = []
        self.config_file = config
        cfg = auxiliaryfunctions.read_config(config)
        self.humanscorer = cfg["scorer"]
        self.move2corner = cfg["move2corner"]
        self.center = cfg["corner2move2"]
        self.colormap = plt.get_cmap(cfg["colormap"]).reversed()
        self.markerSize = cfg["dotsize"]
        self.alpha = cfg["alphavalue"]
        self.iterationindex = cfg["iteration"]
        self.project_path = cfg["project_path"]
        self.bodyparts = cfg["bodyparts"]
        self.threshold = 0.4  # default likelihood threshold, user-adjustable on load
        self.img_size = (10, 6)  # (imgW, imgH) width, height in inches.
        self.preview = False
        self.view_locked = False
        # Workaround for MAC - xlim and ylim changed events seem to be
        # triggered too often, so make sure xlim/ylim actually changed
        # before turning zoom off.
        self.prezoom_xlim = []
        self.prezoom_ylim = []

    def _add_control(self, sizer, control, event_type, handler, enabled=True):
        """Add *control* to *sizer*, bind *handler* to *event_type*, set its enabled state.

        Every widget in the bottom bar uses proportion 1 and a 15 px border.

        Returns:
            The control, so it can be stored as an attribute.
        """
        sizer.Add(control, 1, wx.ALL, 15)
        control.Bind(event_type, handler)
        control.Enable(enabled)
        return control

    def _redraw_image(self, img_name, remove_colorbar=True, **drawplot_kwargs):
        """Redraw ``self.img`` through the image panel and re-hook the zoom callbacks.

        Args:
            img_name: file name shown with the image.
            remove_colorbar: when True, first delete ``figure.axes[1]`` (the
                colorbar axes from the previous draw) so they do not pile up.
            **drawplot_kwargs: forwarded to ``image_panel.drawplot``
                (e.g. ``keep_view``).
        """
        if remove_colorbar:
            self.figure.delaxes(self.figure.axes[1])
        (
            self.figure,
            self.axes,
            self.canvas,
            self.toolbar,
        ) = self.image_panel.drawplot(
            self.img,
            img_name,
            self.iter,
            self.index,
            self.threshold,
            self.bodyparts,
            self.colormap,
            self.preview,
            **drawplot_kwargs,
        )
        self.axes.callbacks.connect("xlim_changed", self.onZoom)
        self.axes.callbacks.connect("ylim_changed", self.onZoom)

    # ###########################################################################
    # functions for button responses
    # ###########################################################################

    # BUTTONS FUNCTIONS FOR HOTKEYS
    def OnKeyPressed(self, event=None):
        """Handle hotkeys.

        Right/Left arrows go to the next/previous image, but only while the
        matching button is enabled. This keeps the hotkeys from running
        before a file is loaded or from stepping past the first image.
        Backspace offers to delete the drawn label closest to the cursor.
        Any other key is skipped so wx keeps processing it.
        """
        if event is None:
            return
        key = event.GetKeyCode()
        if key == wx.WXK_RIGHT:
            if self.next.IsEnabled():
                self.nextImage(event=None)
        elif key == wx.WXK_LEFT:
            if self.prev.IsEnabled():
                self.prevImage(event=None)
        elif key == wx.WXK_BACK:
            if not self.drs:
                # Nothing drawn yet (no file loaded), so there is nothing to remove.
                return
            pos_abs = event.GetPosition()
            inv = self.axes.transData.inverted()
            pos_rel = list(inv.transform(pos_abs))
            y1, y2 = self.axes.get_ylim()
            pos_rel[1] = y1 - pos_rel[1] + y2  # Recall y-axis is inverted
            i = np.nanargmin(
                [self.calc_distance(*dp.point.center, *pos_rel) for dp in self.drs]
            )
            closest_dp = self.drs[i]
            msg = wx.MessageBox(
                "Do you want to remove the label %s ?" % closest_dp.bodyParts,
                "Remove!",
                wx.YES_NO | wx.ICON_WARNING,
            )
            if msg == wx.YES:
                closest_dp.delete_data()
        else:
            # Let wx deliver keys we do not handle (EVT_CHAR_HOOK swallows
            # them unless skipped).
            event.Skip()

    def activateSlider(self, event):
        """Enable the marker-size slider while its checkbox is ticked."""
        self.checkSlider = event.GetEventObject()
        if self.checkSlider.GetValue():
            self.activate_slider = True
            self.slider.Enable(True)
            MainFrame.updateZoomPan(self)
        else:
            self.slider.Enable(False)

    def OnSliderScroll(self, event):
        """Adjust marker size for plotting the annotations and redraw the current image."""
        self.markerSize = self.slider.GetValue()
        MainFrame.saveEachImage(self)
        MainFrame.updateZoomPan(self)
        self.updatedCoords = []
        img_name = Path(*self.index[self.iter]).name
        self._redraw_image(img_name, keep_view=True)
        MainFrame.plot(self, self.img)

    def browseDir(self, event):
        """Ask for a machinelabels file, load it, and show its first image.

        After loading, the user is asked for a likelihood threshold. A
        non-numeric entry keeps the current threshold.
        """
        fname = "machinelabels-iter" + str(self.iterationindex) + ".h5"
        self.statusbar.SetStatusText("Looking for a folder to start refining...")
        if self.jump_unlabeled:
            cwd = str(auxiliaryfunctions.find_next_unlabeled_folder(self.config_file))
        else:
            cwd = os.path.join(os.getcwd(), "labeled-data")

        # macOS gets a generic *.h5 filter with the expected name pre-filled;
        # other platforms filter on the exact machinelabels file name.
        if platform.system() == "Darwin":
            default_file, wildcard = fname, "(*.h5)|*.h5"
        else:
            default_file, wildcard = "", fname
        dlg = wx.FileDialog(
            self,
            "Select the machinelabels-iterX.h5 file.",
            cwd,
            default_file,
            wildcard=wildcard,
            style=wx.FD_OPEN | wx.FD_FILE_MUST_EXIST,
        )
        if dlg.ShowModal() != wx.ID_OK:
            dlg.Destroy()
            self.Destroy()
            return
        self.data_file = dlg.GetPath()
        self.dir = str(Path(self.data_file).parents[0])
        self.fileName = str(Path(self.data_file).stem)
        self.load.Enable(False)
        self.next.Enable(True)
        self.save.Enable(True)
        self.zoom.Enable(True)
        self.pan.Enable(True)
        self.home.Enable(True)
        self.quit.Enable(True)
        self.lock.Enable(True)
        dlg.Destroy()

        self.dataname = str(self.data_file)
        self.statusbar.SetStatusText(
            "Working on folder: {}".format(os.path.split(str(self.dir))[-1])
        )
        self.preview = True
        self.iter = 0

        if not os.path.isfile(self.dataname):
            msg = wx.MessageBox(
                "No Machinelabels file found! Want to retry?",
                "Error!",
                wx.YES_NO | wx.ICON_WARNING,
            )
            if msg == wx.YES:
                self.load.Enable(True)
                self.next.Enable(False)
                self.save.Enable(False)
            return

        self.Dataframe = pd.read_hdf(self.dataname)
        conversioncode.guarantee_multiindex_rows(self.Dataframe)
        self.Dataframe.sort_index(inplace=True)
        self.scorer = self.Dataframe.columns.get_level_values(0)[0]
        self.index = list(self.Dataframe.iloc[:, 0].index)

        # Reading images
        self.img = os.path.join(self.project_path, *self.index[self.iter])
        img_name = Path(self.img).name
        self.norm, self.colorIndex = self.image_panel.getColorIndices(
            self.img, self.bodyparts
        )

        # Adding Slider and Checkbox
        (
            self.choiceBox,
            self.slider,
            self.checkBox,
        ) = self.choice_panel.addCheckBoxSlider(self.bodyparts, self.file, self.markerSize)
        self.slider.Bind(wx.EVT_SLIDER, self.OnSliderScroll)
        self.checkBox.Bind(wx.EVT_CHECKBOX, self.activateSlider)
        self.slider.Enable(False)

        # Show image. The axis title does not show the threshold yet because
        # it has not been chosen. No colorbar exists before this first draw.
        self._redraw_image(img_name, remove_colorbar=False)

        instruction = wx.MessageBox(
            "1. Enter the likelihood threshold. \n\n"
            "2. Each prediction will be shown with a unique color. \n"
            " All the data points above the threshold will be marked as circle"
            " filled with a unique color."
            " All the data points below the threshold will be marked with a"
            " hollow circle. \n\n"
            "3. Enable the checkbox to adjust the marker size. \n\n"
            "4. Hover your mouse over data points to see the labels and their"
            " likelihood. \n\n"
            "5. Left click and drag to move the data points. \n\n"
            "6. Middle click on any data point to remove it. Be careful, you"
            " cannot undo this step. \n"
            " Click once on the zoom button to zoom-in the image.The cursor will"
            " become cross, click and drag over a point to zoom in. \n"
            " Click on the zoom button again to disable the zooming function and"
            " recover the cursor. \n"
            " Use pan button to pan across the image while zoomed in."
            " Use home button to go back to the full;default view. \n\n"
            "7. When finished click 'Save' to save all the changes. \n\n"
            "8. Click OK to continue",
            "User instructions",
            wx.OK | wx.ICON_INFORMATION,
        )
        if instruction == wx.OK:
            # If OK is selected, the image is redrawn with the likelihood
            # threshold the user enters.
            textBox = wx.TextEntryDialog(
                self,
                "Select the likelihood threshold",
                caption="Enter the threshold",
                value="0.4",
            )
            textBox.ShowModal()
            try:
                self.threshold = float(textBox.GetValue())
            except ValueError:
                # Keep the current threshold instead of aborting the load.
                print("Invalid threshold; keeping %s" % self.threshold)
            textBox.Destroy()
            self.img = os.path.join(self.project_path, *self.index[self.iter])
            img_name = Path(self.img).name
            self.axes.clear()
            self.preview = False
        self._redraw_image(img_name)
        MainFrame.plot(self, self.img)
        MainFrame.saveEachImage(self)

    def nextImage(self, event):
        """Save edits on the current image and show the next one.

        If the next image is completely black (max intensity 0), the user may
        remove that row from the dataframe. The previous image is then shown.
        """
        # Checks for the last image and disables the Next button
        if len(self.index) - self.iter == 1:
            self.next.Enable(False)
            return
        self.prev.Enable(True)

        # Checks if zoom/pan button is ON
        MainFrame.updateZoomPan(self)
        MainFrame.saveEachImage(self)

        self.statusbar.SetStatusText(
            "Working on folder: {}".format(os.path.split(str(self.dir))[-1])
        )
        self.iter = self.iter + 1

        if len(self.index) > self.iter:
            self.updatedCoords = []
            self.img = os.path.join(self.project_path, *self.index[self.iter])
            img_name = Path(self.img).name
            self._redraw_image(img_name, keep_view=self.view_locked)
            im = io.imread(self.img)
            if np.max(im) == 0:
                msg = wx.MessageBox(
                    "Invalid image. Click Yes to remove",
                    "Error!",
                    wx.YES_NO | wx.ICON_WARNING,
                )
                if msg == wx.YES:
                    self.Dataframe = self.Dataframe.drop(self.index[self.iter])
                    self.index = list(self.Dataframe.iloc[:, 0].index)
                    self.iter = self.iter - 1
                    self.img = os.path.join(self.project_path, *self.index[self.iter])
                    img_name = Path(self.img).name
                    # Also drop the colorbar added by the draw above;
                    # otherwise a stale colorbar stays in the figure.
                    self._redraw_image(img_name, keep_view=self.view_locked)
            MainFrame.plot(self, self.img)
        else:
            self.next.Enable(False)
        MainFrame.saveEachImage(self)

    def prevImage(self, event):
        """Save edits on the current image and show the previous one."""
        if self.iter == 0:
            # Already at the first image. Without this guard iter would
            # become -1 and the final saveEachImage would write this image's
            # points into the last row of the dataframe.
            self.prev.Enable(False)
            return
        MainFrame.saveEachImage(self)

        # Checks if zoom/pan button is ON
        MainFrame.updateZoomPan(self)

        self.statusbar.SetStatusText(
            "Working on folder: {}".format(os.path.split(str(self.dir))[-1])
        )
        self.next.Enable(True)
        self.iter = self.iter - 1

        # Checks for the first image and disables the Previous button
        if self.iter == 0:
            self.prev.Enable(False)

        self.updatedCoords = []
        # Reading Image
        self.img = os.path.join(self.project_path, *self.index[self.iter])
        img_name = Path(self.img).name
        self._redraw_image(img_name, keep_view=self.view_locked)
        MainFrame.plot(self, self.img)
        MainFrame.saveEachImage(self)

    def quitButton(self, event):
        """Ask for confirmation and close the GUI; otherwise re-enable Save."""
        self.statusbar.SetStatusText("")
        dlg = wx.MessageDialog(None, "Are you sure?", "Quit!", wx.YES_NO | wx.ICON_WARNING)
        result = dlg.ShowModal()
        dlg.Destroy()  # modal dialogs must be destroyed explicitly
        if result == wx.ID_YES:
            print(
                "Closing... The refined labels are stored in a subdirectory under labeled-data."
                " Use the function 'merge_datasets' to augment the training dataset,"
                " and then re-train a network using create_training_dataset followed by train_network!"
            )
            self.Destroy()
        else:
            self.save.Enable(True)

    def helpButton(self, event):
        """Opens Instructions"""
        self.statusbar.SetStatusText("Help")
        # Checks if zoom/pan button is ON
        MainFrame.updateZoomPan(self)
        wx.MessageBox(
            "1. Enter the likelihood threshold. \n\n"
            "2. All the data points above the threshold will be marked as circle"
            " filled with a unique color."
            " All the data points below the threshold will be marked with a"
            " hollow circle. \n\n"
            "3. Enable the checkbox to adjust the marker size (you will not be"
            " able to zoom/pan/home until the next frame). \n\n"
            "4. Hover your mouse over data points to see the labels and their"
            " likelihood. \n\n"
            "5. LEFT click+drag to move the data points. \n\n"
            "6. MIDDLE click on any data point to remove it. Be careful, you"
            " cannot undo this step! \n"
            " Click once on the zoom button to zoom-in the image. The cursor will"
            " become cross, click and drag over a point to zoom in. \n"
            " Click on the zoom button again to disable the zooming function and"
            " recover the cursor. \n"
            " Use pan button to pan across the image while zoomed in."
            " Use home button to go back to the full default view. \n\n"
            "7. When finished click 'Save' to save all the changes. \n\n"
            "8. Click OK to continue",
            "User instructions",
            wx.OK | wx.ICON_INFORMATION,
        )

    def onChecked(self, event):
        """Save the current edits and toggle the slider with the checkbox state."""
        MainFrame.saveEachImage(self)
        self.cb = event.GetEventObject()
        if self.cb.GetValue():
            self.slider.Enable(True)
        else:
            self.slider.Enable(False)

    def check_labels(self):
        """Set to NaN any label whose x/y lies outside its image bounds.

        Returns:
            ``self.Dataframe`` (modified in place).
        """
        print("Checking labels if they are outside the image")
        for i in self.Dataframe.index:
            image_name = os.path.join(self.project_path, *i)
            # Only the size is needed; the context manager closes the file.
            with PIL.Image.open(image_name) as im:
                width, height = im.size
            for bp in self.bodyparts:
                x = self.Dataframe.loc[i, (self.scorer, bp, "x")]
                y = self.Dataframe.loc[i, (self.scorer, bp, "y")]
                if x > width or x < 0 or y > height or y < 0:
                    print("Found %s outside the image %s.Setting it to NaN" % (bp, i))
                    self.Dataframe.loc[i, (self.scorer, bp, "x")] = np.nan
                    self.Dataframe.loc[i, (self.scorer, bp, "y")] = np.nan
        return self.Dataframe

    def saveDataSet(self, event):
        """Write the refined labels to disk.

        The machine-label file is overwritten. The labels are then relabelled
        with the human scorer, stripped of likelihood, and merged into (or
        written as) ``CollectedData_<scorer>.h5/.csv``. On merge, refined
        rows take precedence.
        """
        MainFrame.saveEachImage(self)

        # Checks if zoom/pan button is ON
        MainFrame.updateZoomPan(self)

        self.statusbar.SetStatusText("File saved")
        self.Dataframe = MainFrame.check_labels(self)
        # Overwrite machine label file
        self.Dataframe.to_hdf(self.dataname, key="df_with_missing", mode="w")

        # Rename the scorer level to the human scorer. MultiIndex.set_levels
        # returns a new index; its `inplace` argument no longer exists in pandas 2.
        self.Dataframe.columns = self.Dataframe.columns.set_levels(
            [self.humanscorer], level=0
        )
        self.Dataframe = self.Dataframe.drop("likelihood", axis=1, level=2)

        collected_base = os.path.join(self.dir, "CollectedData_" + self.humanscorer)
        if Path(collected_base + ".h5").is_file():
            print(
                "A training dataset file is already found for this video."
                " The refined machine labels are merged to this data!"
            )
            DataU1 = pd.read_hdf(collected_base + ".h5")
            # Combine refined machine labels with the original collected data,
            # keeping the first occurrence so the refined labels win.
            DataCombined = pd.concat([self.Dataframe, DataU1])
            DataCombined = DataCombined[~DataCombined.index.duplicated(keep="first")]
        else:
            DataCombined = self.Dataframe
        DataCombined.sort_index(inplace=True)
        DataCombined.to_hdf(collected_base + ".h5", key="df_with_missing", mode="w")
        DataCombined.to_csv(collected_base + ".csv")

        self.next.Enable(False)
        self.prev.Enable(False)
        self.slider.Enable(False)
        self.checkBox.Enable(False)
        nextFilemsg = wx.MessageBox(
            "File saved. Do you want to refine another file?",
            "Repeat?",
            wx.YES_NO | wx.ICON_INFORMATION,
        )
        if nextFilemsg == wx.YES:
            self.file = 1
            self.updatedCoords = []
            self.dataFrame = None
            self.prev.Enable(False)
            self.figure.delaxes(self.figure.axes[1])
            self.axes.clear()
            self.choiceBox.Clear(True)
            MainFrame.updateZoomPan(self)
            MainFrame.browseDir(self, event)

    # ###########################################################################
    # Other functions
    # ###########################################################################
    def saveEachImage(self):
        """Copy the last dragged position of each body part into the current row."""
        for bpindex, bp in enumerate(self.bodyparts):
            if self.updatedCoords[bpindex]:
                row = self.Dataframe.index[self.iter]
                last = self.updatedCoords[bpindex][-1]
                self.Dataframe.loc[row, (self.scorer, bp, "x")] = last[0]
                self.Dataframe.loc[row, (self.scorer, bp, "y")] = last[1]

    def getLabels(self, img_index):
        """Return ``[[x, y, bodypart, bodypart_index]]`` per body part for row *img_index*.

        The result is also stored in ``self.previous_image_points``.
        """
        self.previous_image_points = [
            [
                [
                    self.Dataframe[self.scorer][bp]["x"].values[img_index],
                    self.Dataframe[self.scorer][bp]["y"].values[img_index],
                    bp,
                    bpindex,
                ]
            ]
            for bpindex, bp in enumerate(self.bodyparts)
        ]
        return self.previous_image_points

    def plot(self, im):
        """Draw one draggable circle per body part for the current image.

        Points below ``self.threshold`` are drawn hollow, except for
        ``CollectedData_`` files, whose points are always filled. Each
        point's coordinate list is appended to ``self.updatedCoords``.

        Args:
            im: path to the current image. It is read to get its shape for
                the ``move2corner`` option.
        """
        im = io.imread(im)
        ny, nx = np.shape(im)[0], np.shape(im)[1]
        self.drs = []
        # Fetch this image's labels once rather than once per body part.
        labels = MainFrame.getLabels(self, self.iter)
        # Human-annotated files have no likelihood column; treat points as certain.
        is_collected = "CollectedData_" in self.fileName
        for bpindex, bp in enumerate(self.bodyparts):
            color = self.colormap(self.norm(self.colorIndex[bpindex]))
            bp_data = self.Dataframe[self.scorer][bp]
            if is_collected:
                likelihood = 1.0
            else:
                likelihood = bp_data["likelihood"].values[self.iter]
            self.points = [
                bp_data["x"].values[self.iter],
                bp_data["y"].values[self.iter],
                likelihood,
            ]
            self.likelihood = self.points[2]
            if self.move2corner:
                # Out-of-image points are drawn at the configured corner instead.
                if self.points[0] > nx or self.points[0] < 0:
                    self.points[0] = self.center[0]
                if self.points[1] > ny or self.points[1] < 0:
                    self.points[1] = self.center[1]
            if not is_collected and self.likelihood < self.threshold:
                circle = patches.Circle(
                    (self.points[0], self.points[1]),
                    radius=self.markerSize,
                    facecolor="None",
                    edgecolor=color,
                )
            else:
                circle = patches.Circle(
                    (self.points[0], self.points[1]),
                    radius=self.markerSize,
                    fc=color,
                    alpha=self.alpha,
                )
            self.axes.add_patch(circle)
            self.dr = auxfun_drag.DraggablePoint(circle, bp, likelihood=self.likelihood)
            self.dr.connect()
            self.dr.coords = labels[bpindex]
            self.drs.append(self.dr)
            self.updatedCoords.append(self.dr.coords)
        self.figure.canvas.draw()
class MainFrame(BaseFrame):
    """Manual frame-extraction GUI.

    Lets the user load a video listed in the project config and scrub through it
    with a slider. The user can optionally draw a crop rectangle, then save the
    current frame, or a range of frames, as PNGs under
    ``<project>/labeled-data/<video stem>/``.
    """

    def __init__(self, parent, config, slider_width=25):
        # Zero-argument super(): this module defines several classes named
        # ``MainFrame``. ``super(MainFrame, self)`` would look the name up at
        # call time and pick whichever class was defined last.
        super().__init__(
            "DeepLabCut2.0 - Manual Frame Extraction",
            parent,
        )
        ###################################################################
        # Split the frame into top and bottom panels. The bottom panel holds
        # the widgets; the top panel shows the images and plots.
        topSplitter = wx.SplitterWindow(self)
        self.image_panel = ImagePanel(topSplitter, config, self.gui_size)
        self.widget_panel = WidgetPanel(topSplitter)
        # wx expects integer pixel positions; newer Python/wxPython combinations
        # reject floats with a TypeError.
        topSplitter.SplitHorizontally(
            self.image_panel,
            self.widget_panel,
            sashPosition=int(self.gui_size[1] * 0.83),
        )
        topSplitter.SetSashGravity(1)
        sizer = wx.BoxSizer(wx.VERTICAL)
        sizer.Add(topSplitter, 1, wx.EXPAND)
        self.SetSizer(sizer)

        ###################################################################
        # Add buttons to the widget panel and bind them to their handlers.
        widgetsizer = wx.WrapSizer(orient=wx.HORIZONTAL)
        self.load = wx.Button(self.widget_panel, id=wx.ID_ANY, label="Load Video")
        widgetsizer.Add(self.load, 1, wx.ALL, 15)
        self.load.Bind(wx.EVT_BUTTON, self.browseDir)

        self.help = wx.Button(self.widget_panel, id=wx.ID_ANY, label="Help")
        widgetsizer.Add(self.help, 1, wx.ALL, 15)
        self.help.Bind(wx.EVT_BUTTON, self.helpButton)

        self.grab = wx.Button(self.widget_panel, id=wx.ID_ANY, label="Grab Frames")
        widgetsizer.Add(self.grab, 1, wx.ALL, 15)
        self.grab.Bind(wx.EVT_BUTTON, self.grabFrame)
        self.grab.Enable(False)

        widgetsizer.AddStretchSpacer(5)
        # Slider width is slider_width percent of the GUI width (int for wx).
        size_x = int(round(self.gui_size[0] * (slider_width / 100)))
        self.slider = wx.Slider(
            self.widget_panel,
            id=wx.ID_ANY,
            value=0,
            minValue=0,
            maxValue=1,
            size=(size_x, -1),
            style=wx.SL_HORIZONTAL | wx.SL_AUTOTICKS | wx.SL_LABELS,
        )
        widgetsizer.Add(self.slider, 1, wx.ALL, 5)
        self.slider.Hide()

        widgetsizer.AddStretchSpacer(5)
        self.start_frames_sizer = wx.BoxSizer(wx.VERTICAL)
        self.end_frames_sizer = wx.BoxSizer(wx.VERTICAL)

        self.start_frames_sizer.AddSpacer(15)
        self.startFrame = wx.SpinCtrl(
            self.widget_panel, value="0", size=(100, -1), min=0, max=120
        )
        self.startFrame.Bind(wx.EVT_SPINCTRL, self.updateSlider)
        self.startFrame.Enable(False)
        self.start_frames_sizer.Add(self.startFrame, 1, wx.EXPAND | wx.ALIGN_LEFT, 15)
        start_text = wx.StaticText(self.widget_panel, label="Start Frame Index")
        self.start_frames_sizer.Add(start_text, 1, wx.EXPAND | wx.ALIGN_LEFT, 15)
        self.checkBox = wx.CheckBox(
            self.widget_panel, id=wx.ID_ANY, label="Range of frames"
        )
        self.checkBox.Bind(wx.EVT_CHECKBOX, self.activate_frame_range)
        self.start_frames_sizer.Add(self.checkBox, 1, wx.EXPAND | wx.ALIGN_LEFT, 15)

        self.end_frames_sizer.AddSpacer(15)
        self.endFrame = wx.SpinCtrl(
            self.widget_panel, value="1", size=(160, -1), min=1, max=120
        )
        self.endFrame.Enable(False)
        self.end_frames_sizer.Add(self.endFrame, 1, wx.EXPAND | wx.ALIGN_LEFT, 15)
        end_text = wx.StaticText(self.widget_panel, label="Number of Frames")
        self.end_frames_sizer.Add(end_text, 1, wx.EXPAND | wx.ALIGN_LEFT, 15)
        self.updateFrame = wx.Button(self.widget_panel, id=wx.ID_ANY, label="Update")
        self.end_frames_sizer.Add(self.updateFrame, 1, wx.EXPAND | wx.ALIGN_LEFT, 15)
        self.updateFrame.Bind(wx.EVT_BUTTON, self.updateSlider)
        self.updateFrame.Enable(False)

        widgetsizer.Add(self.start_frames_sizer, 1, wx.ALL, 0)
        widgetsizer.AddStretchSpacer(5)
        widgetsizer.Add(self.end_frames_sizer, 1, wx.ALL, 0)
        widgetsizer.AddStretchSpacer(15)

        self.quit = wx.Button(self.widget_panel, id=wx.ID_ANY, label="Quit")
        widgetsizer.Add(self.quit, 1, wx.ALL, 15)
        self.quit.Bind(wx.EVT_BUTTON, self.quitButton)
        self.quit.Enable(True)

        # Hide the frame-index controls until a video has been loaded.
        self.start_frames_sizer.ShowItems(show=False)
        self.end_frames_sizer.ShowItems(show=False)

        # SetSizerAndFit installs the sizer itself; a prior SetSizer is redundant.
        self.widget_panel.SetSizerAndFit(widgetsizer)
        self.widget_panel.Layout()

        # Variables initialization
        self.numberFrames = 0
        self.currFrame = 0
        self.figure = Figure()
        self.axes = self.figure.add_subplot(111)
        self.drs = []
        self.cfg = auxiliaryfunctions.read_config(config)
        self.Task = self.cfg["Task"]
        self.start = self.cfg["start"]
        self.stop = self.cfg["stop"]
        self.date = self.cfg["date"]
        self.trainFraction = self.cfg["TrainingFraction"][0]
        self.videos = list(
            self.cfg.get("video_sets_original") or self.cfg["video_sets"]
        )
        self.bodyparts = self.cfg["bodyparts"]
        self.colormap = plt.get_cmap(self.cfg["colormap"]).reversed()
        self.markerSize = self.cfg["dotsize"]
        self.alpha = self.cfg["alphavalue"]
        self.video_names = [Path(i).stem for i in self.videos]
        self.config_path = Path(config)
        self.extract_range_frame = False
        self.extract_from_analyse_video = False
        # Cropping state. browseDir decides whether cropping is used.
        # CheckCropping stores the RectangleSelector in self.cid.
        self.cropping = False
        self.cid = None

    def quitButton(self, event):
        """
        Quits the GUI after asking the user for confirmation.
        """
        self.statusbar.SetStatusText("")
        dlg = wx.MessageDialog(
            None, "Are you sure?", "Quit!", wx.YES_NO | wx.ICON_WARNING
        )
        result = dlg.ShowModal()
        # Modal dialogs are not destroyed automatically.
        dlg.Destroy()
        if result == wx.ID_YES:
            print("Quitting for now!")
            self.Destroy()

    def updateSlider(self, event):
        """Move the slider to the start-frame spin control value and redraw."""
        self.slider.SetValue(self.startFrame.GetValue())
        self.currFrame = self.slider.GetValue()
        if self.extract_from_analyse_video:
            self.figure.delaxes(self.figure.axes[1])
            self.plot_labels()
        self.update()

    def activate_frame_range(self, event):
        """
        Activates the frame range boxes
        """
        self.checkSlider = event.GetEventObject()
        if self.checkSlider.GetValue():
            self.extract_range_frame = True
            self.startFrame.Enable(True)
            self.startFrame.SetValue(self.slider.GetValue())
            self.endFrame.Enable(True)
            self.updateFrame.Enable(True)
            # Grab is re-enabled by update(), i.e. after the user clicks "Update".
            self.grab.Enable(False)
        else:
            self.extract_range_frame = False
            self.startFrame.Enable(False)
            self.endFrame.Enable(False)
            self.updateFrame.Enable(False)
            self.grab.Enable(True)

    def line_select_callback(self, eclick, erelease):
        "eclick and erelease are the press and release events"
        self.new_x1, self.new_y1 = eclick.xdata, eclick.ydata
        self.new_x2, self.new_y2 = erelease.xdata, erelease.ydata

    def CheckCropping(self):
        """Load the stored crop box for the current video and let the user redraw it.

        The crop entry in the config is a string ``"x1, x2, y1, y2"`` in pixels.
        (x1, y1) is the top-left corner and (x2, y2) the bottom-right corner of
        the cropped region. The entry is looked up under ``video_sets`` first,
        then under ``video_sets_original``.
        """
        videosource = self.video_source
        try:
            crop = self.cfg["video_sets"][videosource]["crop"]
        except KeyError:
            crop = self.cfg["video_sets_original"][videosource]["crop"]
        self.x1, self.x2, self.y1, self.y2 = map(int, crop.split(",")[:4])

        if self.cropping:
            # Select the ROI by drawing a rectangle with the left mouse button.
            # The selector registers its own canvas callbacks, so no extra
            # mpl_connect is needed. Keeping the reference also stops it from
            # being garbage-collected. ``drawtype="box"`` is the default and
            # newer matplotlib versions no longer accept the keyword.
            self.cid = RectangleSelector(
                self.axes,
                self.line_select_callback,
                useblit=False,
                button=[1],
                minspanx=5,
                minspany=5,
                spancoords="pixels",
                interactive=True,
            )

    def OnSliderScroll(self, event):
        """
        Slider to scroll through the video
        """
        # The grab button keeps its grabFrame binding from __init__ / is_crop_ok.
        # Re-binding here on every slider event only piled up handlers.
        self.axes.clear()
        self.currFrame = self.slider.GetValue()
        self.startFrame.SetValue(self.currFrame)
        self.update()

    def _show_extraction_controls(self):
        """Reveal the slider and frame-index controls and bound them by the video length."""
        self.slider.Show()
        self.start_frames_sizer.ShowItems(show=True)
        self.end_frames_sizer.ShowItems(show=True)
        self.widget_panel.Layout()
        last_index = max(self.numberFrames - 1, 0)
        self.slider.SetMax(last_index)
        self.startFrame.SetMax(last_index)
        # "Number of Frames" may cover the whole video.
        self.endFrame.SetMax(max(self.numberFrames, 1))

    def is_crop_ok(self, event):
        """
        Accept the drawn crop box, save it to the config and show the extraction controls.
        """
        self.grab.SetLabel("Grab Frames")
        self.grab.Bind(wx.EVT_BUTTON, self.grabFrame)
        self._show_extraction_controls()

        # Use the rectangle drawn by the user. If nothing was drawn, keep the
        # box loaded from the config by CheckCropping. Coordinates are ordered
        # so the slices in chooseFrame are never empty.
        if hasattr(self, "new_x1"):
            x_lo, x_hi = sorted((self.new_x1, self.new_x2))
            y_lo, y_hi = sorted((self.new_y1, self.new_y2))
            self.x1, self.x2 = int(x_lo), int(x_hi)
            self.y1, self.y2 = int(y_lo), int(y_hi)

        # self.cid is the RectangleSelector, not a connection id, so
        # mpl_disconnect cannot remove it; detach its event handlers instead.
        if self.cid is not None:
            self.cid.disconnect_events()
            self.cid = None
        self.axes.clear()
        self.currFrame = self.slider.GetValue()
        self.update()
        # Update the config.yaml file
        self.cfg["video_sets"][self.video_source] = {
            "crop": ", ".join(map(str, [self.x1, self.x2, self.y1, self.y2]))
        }
        auxiliaryfunctions.write_config(self.config_path, self.cfg)

    def browseDir(self, event):
        """
        Show the File Dialog and ask the user to select the video file
        """
        self.statusbar.SetStatusText("Looking for a video to start extraction..")
        dlg = wx.FileDialog(
            self, "SELECT A VIDEO", os.getcwd(), "", "*.*", wx.FD_OPEN
        )
        accepted = dlg.ShowModal() == wx.ID_OK
        if accepted:
            self.video_source_original = dlg.GetPath()
            self.video_source = str(Path(self.video_source_original).resolve())
        dlg.Destroy()
        if not accepted:
            # Cancelled: there is no video to work with, so close and stop here.
            self.Close(True)
            return
        self.load.Enable(False)

        selectedvideo = Path(self.video_source)
        self.statusbar.SetStatusText(
            "Working on video: {}".format(os.path.split(str(selectedvideo))[-1])
        )
        if str(selectedvideo.stem) not in self.video_names:
            wx.MessageBox(
                "Video file is not in config file. Use add function to add this video in the config file and retry!",
                "Error!",
                wx.OK | wx.ICON_WARNING,
            )
            self.Close(True)
            return

        self.grab.Enable(True)
        self.vid = cv2.VideoCapture(self.video_source)
        self.videoPath = os.path.dirname(self.video_source)
        self.filename = Path(self.video_source).name
        self.numberFrames = int(self.vid.get(cv2.CAP_PROP_FRAME_COUNT))
        # Checks if the video is corrupt.
        if not self.vid.isOpened():
            msg = wx.MessageBox(
                "Invalid Video file!Do you want to retry?",
                "Error!",
                wx.YES_NO | wx.ICON_WARNING,
            )
            if msg == wx.YES:
                self.load.Enable(True)
                self.browseDir(event)
            else:
                self.Destroy()
            # Either the nested call handled the new video or the GUI is gone.
            return

        self.slider.Bind(wx.EVT_SLIDER, self.OnSliderScroll)
        self.update()

        cropMsg = wx.MessageBox(
            "Do you want to crop the frames?",
            "Want to crop?",
            wx.YES_NO | wx.ICON_INFORMATION,
        )
        if cropMsg == wx.YES:
            self.cropping = True
            # Until the crop is confirmed, the grab button accepts the box.
            self.grab.SetLabel("Set cropping parameters")
            self.grab.Bind(wx.EVT_BUTTON, self.is_crop_ok)
            self.widget_panel.Layout()
            self.basefolder = "data-" + self.Task + "/"
            self.CheckCropping()
        else:
            self.cropping = False
            self._show_extraction_controls()

    def update(self):
        """
        Updates the image with the current slider index
        """
        self.grab.Enable(True)
        self.figure, self.axes, self.canvas = self.image_panel.getfigure()
        self.vid.set(cv2.CAP_PROP_POS_FRAMES, self.currFrame)
        ret, frame = self.vid.read()
        if ret:
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            self.ax = self.axes.imshow(frame)
            self.axes.set_title(
                str(self.currFrame)
                + "/"
                + str(self.numberFrames - 1)
                + " "
                + self.filename
            )
        self.figure.canvas.draw()

    def chooseFrame(self):
        """Save the frame at the capture's current position as a PNG.

        The file is written to ``labeled-data/<video stem>/img<currFrame>.png``
        next to the config, and is cropped when cropping is enabled.
        """
        ret, frame = self.vid.read()
        if not ret:
            # Reading past the end of the video, or a decode failure.
            print(
                "Could not read frame {} of {}; skipping.".format(
                    self.currFrame, self.filename
                )
            )
            return
        fname = Path(self.filename)
        output_path = self.config_path.parents[0] / "labeled-data" / fname.stem
        if not output_path.exists():
            print(
                "%s path not found. Please make sure that the video was added to the config file using the function 'deeplabcut.add_new_videos'."
                % output_path
            )
            return
        frame = img_as_ubyte(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        # Zero-pad the index so file names sort in frame order.
        n_digits = int(np.ceil(np.log10(self.numberFrames)))
        img_name = (
            str(output_path) + "/img" + str(self.currFrame).zfill(n_digits) + ".png"
        )
        if self.cropping:
            frame = frame[self.y1 : self.y2, self.x1 : self.x2]
        cv2.imwrite(img_name, cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))

    def grabFrame(self, event):
        """Save the current frame, or the selected range of frames, to labeled-data.

        With "Range of frames" checked, the number of frames given in the
        "Number of Frames" control is saved, starting at the current frame and
        stopping at the end of the video. Otherwise only the current frame is
        saved.
        """
        start = self.currFrame
        n_frames = self.endFrame.GetValue() if self.extract_range_frame else 1
        stop = min(start + n_frames, self.numberFrames)
        for index in range(start, stop):
            # chooseFrame names the output file after self.currFrame.
            self.currFrame = index
            self.vid.set(cv2.CAP_PROP_POS_FRAMES, index)
            self.chooseFrame()
        # Keep currFrame in sync with the slider, which has not moved.
        self.currFrame = start

    def plot_labels(self):
        """
        Plots the labels of the analyzed video
        """
        # NOTE(review): reads self.Dataframe and self.scorer, which this class
        # never sets; presumably assigned externally when
        # extract_from_analyse_video is enabled. Confirm.
        self.vid.set(cv2.CAP_PROP_POS_FRAMES, self.currFrame)
        ret, frame = self.vid.read()
        if ret:
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            self.norm = mcolors.Normalize(vmin=np.min(frame), vmax=np.max(frame))
            self.colorIndex = np.linspace(
                np.min(frame), np.max(frame), len(self.bodyparts)
            )
            divider = make_axes_locatable(self.axes)
            cax = divider.append_axes("right", size="5%", pad=0.05)
            cbar = self.figure.colorbar(
                self.ax, cax=cax, spacing="proportional", ticks=self.colorIndex
            )
            cbar.set_ticklabels(self.bodyparts)
            for bpindex, bp in enumerate(self.bodyparts):
                color = self.colormap(self.norm(self.colorIndex[bpindex]))
                self.points = [
                    self.Dataframe[self.scorer][bp]["x"].values[self.currFrame],
                    self.Dataframe[self.scorer][bp]["y"].values[self.currFrame],
                    1.0,
                ]
                circle = [
                    patches.Circle(
                        (self.points[0], self.points[1]),
                        radius=self.markerSize,
                        fc=color,
                        alpha=self.alpha,
                    )
                ]
                self.axes.add_patch(circle[0])
            self.figure.canvas.draw()

    def helpButton(self, event):
        """
        Opens Instructions
        """
        wx.MessageBox(
            "1. Use the Load Video button to load a video. Use the slider to select a frame in the entire video. The number mentioned on the top of the slider represents the frame index. \n\n2. Click Grab Frames button to save the specific frame.\n\n3. In events where you need to extract a range of frames, then use the checkbox Range of frames to select the start frame index and number of frames to extract. Click the update button to see the start frame index. Click Grab Frames to select the range of frames. \n\n Click OK to continue",
            "Instructions to use!",
            wx.OK | wx.ICON_INFORMATION,
        )
class MainFrame(BaseFrame):
    """GUI for drawing a crop rectangle on a single image.

    The most recently selected box is stored in ``self.coords`` as the strings
    ``[x1, x2, y1, y2]`` (pixels, with (x1, y1) the top-left corner and
    (x2, y2) the bottom-right corner). The list stays empty until a box is drawn.
    """

    def __init__(self, parent, config, image):
        # Zero-argument super(): this module defines more than one class named
        # ``MainFrame``, so naming the class explicitly is fragile.
        super().__init__(
            "DeepLabCut2.0 - Select Crop Parameters",
            parent,
        )
        ###################################################################
        # Split the frame into top and bottom panels. The bottom panel holds
        # the widgets; the top panel shows the image.
        topSplitter = wx.SplitterWindow(self)
        self.image_panel = BasePanel(topSplitter, config, self.gui_size)
        self.widget_panel = WidgetPanel(topSplitter)
        # wx expects an integer pixel position; newer Python/wxPython
        # combinations reject floats with a TypeError.
        topSplitter.SplitHorizontally(
            self.image_panel,
            self.widget_panel,
            sashPosition=int(self.gui_size[1] * 0.83),
        )
        topSplitter.SetSashGravity(1)
        sizer = wx.BoxSizer(wx.VERTICAL)
        sizer.Add(topSplitter, 1, wx.EXPAND)
        self.SetSizer(sizer)

        ###################################################################
        # Add buttons to the widget panel and bind them to their handlers.
        widgetsizer = wx.WrapSizer(orient=wx.HORIZONTAL)
        self.help = wx.Button(self.widget_panel, id=wx.ID_ANY, label="Help")
        widgetsizer.Add(self.help, 1, wx.ALL, 15)
        self.help.Bind(wx.EVT_BUTTON, self.helpButton)

        self.quit = wx.Button(
            self.widget_panel, id=wx.ID_ANY, label="Save parameters and Quit"
        )
        widgetsizer.Add(self.quit, 1, wx.ALL, 15)
        self.quit.Bind(wx.EVT_BUTTON, self.quitButton)

        # SetSizerAndFit installs the sizer itself; a prior SetSizer is redundant.
        self.widget_panel.SetSizerAndFit(widgetsizer)
        self.widget_panel.Layout()

        # Variables initialization
        self.image = image
        self.coords = []
        self.figure = Figure()
        self.axes = self.figure.add_subplot(111)
        self.show_image()

    def quitButton(self, event):
        """
        Closes the GUI without a confirmation prompt.
        """
        self.Destroy()

    def show_image(self):
        """Display ``self.image`` and attach an interactive rectangle selector."""
        self.figure, self.axes, self.canvas = self.image_panel.getfigure()
        self.ax = self.axes.imshow(self.image)
        self.figure.canvas.draw()
        # The left mouse button draws the box. The selector registers its own
        # canvas callbacks, so no extra mpl_connect is needed; keeping the
        # reference stops it from being garbage-collected. ``drawtype="box"``
        # is the default and newer matplotlib versions no longer accept it.
        self.cid = RectangleSelector(
            self.axes,
            self.line_select_callback,
            useblit=False,
            button=[1],
            minspanx=5,
            minspany=5,
            spancoords="pixels",
            interactive=True,
        )

    def line_select_callback(self, eclick, erelease):
        "eclick and erelease are the press and release events"
        # Order the corners so a right-to-left or bottom-to-top drag still
        # yields x1 <= x2 and y1 <= y2.
        x_lo, x_hi = sorted((eclick.xdata, erelease.xdata))
        y_lo, y_hi = sorted((eclick.ydata, erelease.ydata))
        self.coords = [
            str(int(x_lo)),
            str(int(x_hi)),
            str(int(y_lo)),
            str(int(y_hi)),
        ]

    def helpButton(self, event):
        """
        Opens Instructions
        """
        wx.MessageBox(
            "1. Use left click to select the region of interest. A red box will be drawn around the selected region. \n\n2. Use the corner points to expand the box and center to move the box around the image. \n\n3. Click "
            '"Save parameters and Quit"'
            " to save the cropping parameters and close the GUI. \n\n Click OK to continue",
            "Instructions to use!",
            wx.OK | wx.ICON_INFORMATION,
        )