Beispiel #1
0
def test_reader_wrong_inputs(tmp_path):
    """VideoWriter must refuse a directory and a file that is not a video."""
    # A folder instead of a file is rejected with ValueError.
    folder = str(tmp_path)
    with pytest.raises(ValueError):
        VideoWriter(folder)

    # A file holding garbage bytes cannot be opened as a video (IOError).
    bogus_video = tmp_path / "fake.avi"
    bogus_video.write_bytes(b"42")
    bogus_path = str(bogus_video)
    with pytest.raises(IOError):
        VideoWriter(bogus_path)
def _create_video_from_tracks(video,
                              tracks,
                              destfolder,
                              output_name,
                              pcutoff,
                              scale=1):
    """Draw tracked keypoints on every frame of ``video`` and encode them with ffmpeg.

    Each frame is rendered to ``frame<index>.png`` inside ``destfolder`` (frames
    whose PNG already exists are not re-rendered). ``ffmpeg`` is then run with
    ``destfolder`` as its working directory to assemble the PNGs into
    ``output_name``.

    Parameters
    ----------
    video : str
        Path to the source video.
    tracks : dict
        Mapping ``track id -> {frame name -> array}``; each array is reshaped to
        rows of ``(x, y, likelihood)``. A ``"header"`` key, if present, is ignored.
    destfolder : str
        Folder receiving the PNG frames (created if missing).
    output_name : str
        Output video passed to ffmpeg; relative paths resolve inside ``destfolder``.
    pcutoff : float
        Only keypoints whose likelihood is strictly greater than this are drawn.
    scale : float
        Figure scale forwarded to ``visualization.prepare_figure_axes``.
    """
    import subprocess
    from tqdm import tqdm

    # makedirs also creates missing parent folders and tolerates an existing one.
    os.makedirs(destfolder, exist_ok=True)

    vid = VideoWriter(video)
    nframes = len(vid)
    fps = vid.fps  # read before the reader is closed below
    strwidth = int(np.ceil(np.log10(nframes)))  # zero-padding width for frame names
    nx, ny = vid.dimensions
    # Horizontal crop window (currently the full frame width).
    X1, X2 = 0, nx

    trackids = [t for t in tracks if t != "header"]
    # One random RGB colour and one marker artist per real track.
    cc = np.random.rand(len(trackids), 3)
    fig, ax = visualization.prepare_figure_axes(nx, ny, scale)
    im = ax.imshow(np.zeros((ny, nx)))
    markers = sum([ax.plot([], [], ".", c=c) for c in cc], [])
    # The layout never changes between frames, so configure it once.
    fig.subplots_adjust(left=0,
                        bottom=0,
                        right=1,
                        top=1,
                        wspace=0,
                        hspace=0)

    for index in tqdm(range(nframes)):
        imname = "frame" + str(index).zfill(strwidth)
        image_output = os.path.join(destfolder, imname + ".png")
        if os.path.isfile(image_output):
            continue  # already rendered (e.g. by a previous run); skip decoding
        vid.set_to_frame(index)
        frame = vid.read_frame()
        if frame is None:
            continue
        im.set_data(frame[:, X1:X2])
        for n, trackid in enumerate(trackids):
            if imname in tracks[trackid]:
                x, y, p = tracks[trackid][imname].reshape((-1, 3)).T
                visible = p > pcutoff
                markers[n].set_data(x[visible], y[visible])
            else:
                markers[n].set_data([], [])
        fig.savefig(image_output)

    # Release the figure and the video reader; neither is needed for encoding.
    plt.close(fig)
    vid.close()

    outputframerate = 30
    # Run ffmpeg inside destfolder instead of calling os.chdir, so the caller's
    # working directory is left untouched.
    subprocess.call(
        [
            "ffmpeg",
            "-framerate",
            str(fps),
            "-i",
            f"frame%0{strwidth}d.png",
            "-r",
            str(outputframerate),
            output_name,
        ],
        cwd=destfolder,
    )
Beispiel #3
0
def test_writer_rescale(tmp_path, video_clip, target_height):
    """Rescaling to a target height keeps the width/height ratio (within 1 px)."""
    rescaled_path = video_clip.rescale(
        width=-1, height=target_height, dest_folder=str(tmp_path)
    )
    rescaled = VideoWriter(rescaled_path)
    assert rescaled.height == target_height
    # The width must change by the same factor as the height.
    scale_factor = video_clip.height / target_height
    expected_width = video_clip.width // scale_factor
    assert rescaled.width == pytest.approx(expected_width, abs=1)
Beispiel #4
0
def test_writer_split(tmp_path, video_clip):
    """split() raises for a single chunk and otherwise yields evenly sized clips."""
    with pytest.raises(ValueError):
        video_clip.split(1)

    n_parts = 3
    pieces = video_clip.split(n_parts, dest_folder=str(tmp_path))
    assert len(pieces) == n_parts
    first_piece = VideoWriter(pieces[0])
    expected_len = len(video_clip) // n_parts
    # Frame counts may differ by one because of rounding at the split points.
    assert pytest.approx(len(first_piece), abs=1) == expected_len
def AnalyzeMultiAnimalVideo(
    video,
    DLCscorer,
    trainFraction,
    cfg,
    dlc_cfg,
    sess,
    inputs,
    outputs,
    pdindex,
    save_as_csv,
    destfolder=None,
    c_engine=False,
    robust_nframes=False,
):
    """Helper function for analyzing a video with multiple individuals.

    Runs pose/cost prediction over ``video`` and saves the raw predictions plus
    run metadata with ``auxfun_multianimal.SaveFullMultiAnimalData``. Nothing is
    done if ``<destfolder>/<video stem><DLCscorer>_full.pickle`` already exists.

    Parameters
    ----------
    video : str
        Path to the video to analyze.
    DLCscorer : str
        Scorer name appended to the video stem to build output file names.
    trainFraction : float
        Recorded in the metadata only.
    cfg : dict
        Project configuration (reads ``cropping``, ``x1``, ``x2``, ``y1``,
        ``y2`` and ``iteration``).
    dlc_cfg : dict
        Model configuration (reads ``batch_size``).
    sess, inputs, outputs
        Inference session and tensors, forwarded to the prediction helpers.
    pdindex, save_as_csv
        Not used by this function; kept for API compatibility.
    destfolder : str, optional
        Output folder; defaults to the folder containing ``video``.
    c_engine : bool
        Forwarded to the prediction helpers.
    robust_nframes : bool
        If True, use the robust frame count / duration and derive fps from them
        instead of using ``vid.fps``.
    """
    print("Starting to analyze %s" % video)
    vname = Path(video).stem
    videofolder = str(Path(video).parents[0])
    if destfolder is None:
        destfolder = videofolder
    auxiliaryfunctions.attempttomakefolder(destfolder)
    # Build every output path from one common stem rather than splitting the
    # .h5 name on ".h5", which breaks if ".h5" also occurs in destfolder.
    output_stem = os.path.join(destfolder, vname + DLCscorer)
    dataname = output_stem + ".h5"

    if os.path.isfile(output_stem + "_full.pickle"):
        print("Video already analyzed!", dataname)
        return

    print("Loading ", video)
    vid = VideoWriter(video)
    if robust_nframes:
        nframes = vid.get_n_frames(robust=True)
        duration = vid.calc_duration(robust=True)
        fps = nframes / duration
    else:
        nframes = len(vid)
        duration = vid.calc_duration(robust=False)
        fps = vid.fps

    nx, ny = vid.dimensions
    print(
        "Duration of video [s]: ",
        round(duration, 2),
        ", recorded with ",
        round(fps, 2),
        "fps!",
    )
    print(
        "Overall # of frames: ",
        nframes,
        " found with (before cropping) frame dimensions: ",
        nx,
        ny,
    )
    start = time.time()

    print("Starting to extract posture")
    batch_size = int(dlc_cfg["batch_size"])
    if batch_size > 1:
        PredicteData, nframes = GetPoseandCostsF(
            cfg,
            dlc_cfg,
            sess,
            inputs,
            outputs,
            vid,
            nframes,
            batch_size,
            c_engine=c_engine,
        )
    else:
        PredicteData, nframes = GetPoseandCostsS(cfg,
                                                 dlc_cfg,
                                                 sess,
                                                 inputs,
                                                 outputs,
                                                 vid,
                                                 nframes,
                                                 c_engine=c_engine)

    stop = time.time()

    # Same truthiness check as used elsewhere for cfg["cropping"].
    if cfg["cropping"]:
        coords = [cfg["x1"], cfg["x2"], cfg["y1"], cfg["y2"]]
    else:
        coords = [0, nx, 0, ny]

    dictionary = {
        "start": start,
        "stop": stop,
        "run_duration": stop - start,
        "Scorer": DLCscorer,
        "DLC-model-config file": dlc_cfg,
        "fps": fps,
        "batch_size": dlc_cfg["batch_size"],
        "frame_dimensions": (ny, nx),
        "nframes": nframes,
        "iteration (active-learning)": cfg["iteration"],
        "training set fraction": trainFraction,
        "cropping": cfg["cropping"],
        "cropping_parameters": coords,
    }
    metadata = {"data": dictionary}
    print("Saving results in %s..." % (destfolder))

    auxfun_multianimal.SaveFullMultiAnimalData(PredicteData, metadata,
                                               dataname)
def ExtractFramesbasedonPreselection(
    Index,
    extractionalgorithm,
    data,
    video,
    cfg,
    config,
    opencv=True,
    cluster_resizewidth=30,
    cluster_color=False,
    savelabeled=True,
    with_annotations=True,
):
    """Extract a preselected subset of frames of ``video`` into labeled-data.

    Frames are chosen from the candidates in ``Index`` with
    ``extractionalgorithm``. They are written, with the plotted predictions,
    into ``<project_path>/labeled-data/<video stem>``. The code then tries to add
    the video to the project config. If ``with_annotations`` is set, the
    predictions for the chosen frames are stored as
    ``machinelabels-iter<iteration>.h5`` and ``machinelabels.csv``.

    Parameters
    ----------
    Index
        Candidate frame indices, forwarded to the frame-selection routines.
    extractionalgorithm : str
        ``"uniform"`` or ``"kmeans"``; any other value extracts no frames.
    data : pandas.DataFrame or dict
        Predictions looked up by frame index (``data.loc[...]`` / ``data.get``).
        With any other type the function returns before writing machine labels.
    video : str
        Path to the video.
    cfg : dict
        Loaded project configuration.
    config : str
        Path to the project config file, used when registering the video.
    opencv : bool
        Read frames with ``VideoWriter`` (True) or moviepy's ``VideoFileClip``.
    cluster_resizewidth, cluster_color
        Forwarded to the k-means frame selection.
    savelabeled : bool
        Forwarded to the frame plotting helpers.
    with_annotations : bool
        Whether to write the machine-label files for the extracted frames.
    """
    from deeplabcut.create_project import add

    start = cfg["start"]
    stop = cfg["stop"]
    numframes2extract = cfg["numframes2pick"]
    bodyparts = auxiliaryfunctions.IntersectionofBodyPartsandOnesGivenbyUser(
        cfg, "all")

    vname = str(Path(video).stem)
    tmpfolder = os.path.join(cfg["project_path"], "labeled-data", vname)
    if os.path.isdir(tmpfolder):
        print("Frames from video", vname,
              " already extracted (more will be added)!")
    else:
        auxiliaryfunctions.attempttomakefolder(tmpfolder, recursive=True)

    nframes = len(data)
    print("Loading video...")
    if opencv:
        vid = VideoWriter(video)
        fps = vid.fps
        duration = vid.calc_duration()
    else:
        from moviepy.editor import VideoFileClip

        clip = VideoFileClip(video)
        fps = clip.fps
        duration = clip.duration

    if cfg["cropping"]:  # one might want to adjust
        coords = (cfg["x1"], cfg["x2"], cfg["y1"], cfg["y2"])
    else:
        coords = None

    print("Duration of video [s]: ", duration, ", recorded @ ", fps, "fps!")
    print("Overall # of frames: ", nframes,
          "with (cropped) frame dimensions: ")
    if extractionalgorithm == "uniform":
        if opencv:
            frames2pick = frameselectiontools.UniformFramescv2(
                vid, numframes2extract, start, stop, Index)
        else:
            frames2pick = frameselectiontools.UniformFrames(
                clip, numframes2extract, start, stop, Index)
    elif extractionalgorithm == "kmeans":
        if opencv:
            frames2pick = frameselectiontools.KmeansbasedFrameselectioncv2(
                vid,
                numframes2extract,
                start,
                stop,
                cfg["cropping"],
                coords,
                Index,
                resizewidth=cluster_resizewidth,
                color=cluster_color,
            )
        else:
            if cfg["cropping"]:
                # y2 must come from cfg["y2"] (it previously used cfg["x2"]).
                clip = clip.crop(y1=cfg["y1"],
                                 y2=cfg["y2"],
                                 x1=cfg["x1"],
                                 x2=cfg["x2"])
            frames2pick = frameselectiontools.KmeansbasedFrameselection(
                clip,
                numframes2extract,
                start,
                stop,
                Index,
                resizewidth=cluster_resizewidth,
                color=cluster_color,
            )

    else:
        print(
            "Please implement this method yourself! Currently the options are 'kmeans', 'uniform'."
        )
        frames2pick = []

    # Extract frames + frames with plotted labels and store them in the folder
    # (named after the video) under labeled-data.
    print("Let's select frames indices:", frames2pick)
    colors = visualization.get_cmap(len(bodyparts), cfg["colormap"])
    strwidth = int(np.ceil(np.log10(nframes)))  # zero-padding width for names
    for index in frames2pick:
        if opencv:
            PlottingSingleFramecv2(
                vid,
                cfg["cropping"],
                coords,
                data,
                bodyparts,
                tmpfolder,
                index,
                cfg["dotsize"],
                cfg["pcutoff"],
                cfg["alphavalue"],
                colors,
                strwidth,
                savelabeled,
            )
        else:
            PlottingSingleFrame(
                clip,
                data,
                bodyparts,
                tmpfolder,
                index,
                cfg["dotsize"],
                cfg["pcutoff"],
                cfg["alphavalue"],
                colors,
                strwidth,
                savelabeled,
            )
        plt.close("all")

    # close videos
    if opencv:
        vid.close()
    else:
        clip.close()
        del clip

    # Extract annotations based on DeepLabCut and store them in the folder
    # (named after the video) under labeled-data.
    if len(frames2pick) > 0:
        try:
            if cfg["cropping"]:
                add.add_new_videos(
                    config, [video],
                    coords=[coords])  # make sure you pass coords as a list
            else:
                add.add_new_videos(config, [video], coords=None)
        except Exception:
            # Best effort: registering the video must not abort the extraction.
            # Ideally indices already present in CollectedData would be dropped
            # from DataCombined here.
            print(
                "AUTOMATIC ADDING OF VIDEO TO CONFIG FILE FAILED! You need to do this manually for including it in the config.yaml file!"
            )
            print("Videopath:", video, "Coordinates for cropping:", coords)

        if with_annotations:
            machinefile = os.path.join(
                tmpfolder,
                "machinelabels-iter" + str(cfg["iteration"]) + ".h5")
            if isinstance(data, pd.DataFrame):
                df = data.loc[frames2pick]
                df.index = [
                    os.path.join(
                        "labeled-data",
                        vname,
                        "img" + str(index).zfill(strwidth) + ".png",
                    ) for index in df.index
                ]  # exchange index number by file names.
            elif isinstance(data, dict):
                idx = [
                    os.path.join(
                        "labeled-data",
                        vname,
                        "img" + str(index).zfill(strwidth) + ".png",
                    ) for index in frames2pick
                ]
                filename = os.path.join(str(tmpfolder),
                                        f"CollectedData_{cfg['scorer']}.h5")
                try:
                    df_temp = pd.read_hdf(filename, "df_with_missing")
                    columns = df_temp.columns
                except FileNotFoundError:
                    # No human annotations yet: build the column layout from
                    # the config (multi-animal parts, then unique parts).
                    columns = pd.MultiIndex.from_product(
                        [
                            [cfg["scorer"]],
                            cfg["individuals"],
                            cfg["multianimalbodyparts"],
                            ["x", "y"],
                        ],
                        names=["scorer", "individuals", "bodyparts", "coords"],
                    )
                    if cfg["uniquebodyparts"]:
                        columns2 = pd.MultiIndex.from_product(
                            [
                                [cfg["scorer"]],
                                ["single"],
                                cfg["uniquebodyparts"],
                                ["x", "y"],
                            ],
                            names=[
                                "scorer", "individuals", "bodyparts", "coords"
                            ],
                        )
                        df_temp = pd.concat((
                            pd.DataFrame(columns=columns),
                            pd.DataFrame(columns=columns2),
                        ))
                        columns = df_temp.columns
                array = np.full((len(frames2pick), len(columns)), np.nan)
                for i, index in enumerate(frames2pick):
                    data_temp = data.get(index)
                    if data_temp is not None:
                        # Keep only x, y of each detection; likelihoods dropped.
                        vals = np.concatenate(data_temp)[:, :2].flatten()
                        array[i, :len(vals)] = vals
                df = pd.DataFrame(array, index=idx, columns=columns)
            else:
                return
            if Path(machinefile).is_file():
                Data = pd.read_hdf(machinefile, "df_with_missing")
                DataCombined = pd.concat([Data, df])
                # drop duplicate labels:
                DataCombined = DataCombined[~DataCombined.index.duplicated(
                    keep="first")]

                DataCombined.to_hdf(machinefile,
                                    key="df_with_missing",
                                    mode="w")
                DataCombined.to_csv(
                    os.path.join(tmpfolder, "machinelabels.csv")
                )  # this is always the most current one (as reading is from h5)
            else:
                df.to_hdf(machinefile, key="df_with_missing", mode="w")
                df.to_csv(os.path.join(tmpfolder, "machinelabels.csv"))

        print(
            "The outlier frames are extracted. They are stored in the subdirectory labeled-data\\%s."
            % vname)
        print(
            "Once you extracted frames for all videos, use 'refine_labels' to manually correct the labels."
        )
    else:
        print("No frames were extracted.")
Beispiel #7
0
    def __init__(self, parent, config, video, shuffle, Dataframe, savelabeled,
                 multianimal):
        """Build the manual outlier-frame extraction window and load ``video``.

        Parameters
        ----------
        parent
            Parent window, forwarded to the base-class constructor.
        config : str
            Path to the project configuration file (read with ``read_config``).
        video : str
            Video to browse. It is added to the project config when its name is
            not among the configured ``video_sets``.
        shuffle
            Stored as ``self.shuffle``.
        Dataframe
            Stored as ``self.Dataframe`` (presumably the predictions for
            ``video`` — confirm against callers).
        savelabeled
            Stored as ``self.savelabeled``.
        multianimal : bool
            If True, adds the individual/bodypart radio buttons and a
            per-individual colour scheme.
        """
        super(MainFrame,
              self).__init__("DeepLabCut2.0 - Manual Outlier Frame Extraction",
                             parent)

        ###################################################################################################################################################
        # Split the frame into top and bottom panels. The bottom panel contains the widgets;
        # the top panel shows the image (left) next to the choice panel (right).

        topSplitter = wx.SplitterWindow(self)
        vSplitter = wx.SplitterWindow(topSplitter)

        self.image_panel = ImagePanel(vSplitter, self.gui_size)
        self.choice_panel = ScrollPanel(vSplitter)

        # sashPosition is an integer pixel offset; convert explicitly because
        # the wx bindings may reject float arguments.
        vSplitter.SplitVertically(self.image_panel,
                                  self.choice_panel,
                                  sashPosition=int(self.gui_size[0] * 0.8))
        vSplitter.SetSashGravity(1)
        self.widget_panel = WidgetPanel(topSplitter)
        topSplitter.SplitHorizontally(vSplitter,
                                      self.widget_panel,
                                      sashPosition=int(self.gui_size[1] *
                                                       0.83))
        topSplitter.SetSashGravity(1)
        sizer = wx.BoxSizer(wx.VERTICAL)
        sizer.Add(topSplitter, 1, wx.EXPAND)
        self.SetSizer(sizer)

        ###################################################################################################################################################
        # Add Buttons to the WidgetPanel and bind them to their respective functions.

        widgetsizer = wx.WrapSizer(orient=wx.HORIZONTAL)

        self.load_button_sizer = wx.BoxSizer(wx.VERTICAL)
        self.help_button_sizer = wx.BoxSizer(wx.VERTICAL)

        self.help = wx.Button(self.widget_panel, id=wx.ID_ANY, label="Help")
        self.help_button_sizer.Add(self.help, 1, wx.ALL, 15)
        self.help.Bind(wx.EVT_BUTTON, self.helpButton)

        widgetsizer.Add(self.help_button_sizer, 1, wx.ALL, 0)

        self.grab = wx.Button(self.widget_panel,
                              id=wx.ID_ANY,
                              label="Grab Frames")
        widgetsizer.Add(self.grab, 1, wx.ALL, 15)
        self.grab.Bind(wx.EVT_BUTTON, self.grabFrame)
        self.grab.Enable(True)

        widgetsizer.AddStretchSpacer(5)
        self.slider = wx.Slider(
            self.widget_panel,
            id=wx.ID_ANY,
            value=0,
            minValue=0,
            maxValue=1,
            size=(200, -1),
            style=wx.SL_HORIZONTAL | wx.SL_AUTOTICKS | wx.SL_LABELS,
        )
        widgetsizer.Add(self.slider, 1, wx.ALL, 5)
        self.slider.Bind(wx.EVT_SLIDER, self.OnSliderScroll)

        widgetsizer.AddStretchSpacer(5)
        self.start_frames_sizer = wx.BoxSizer(wx.VERTICAL)
        self.end_frames_sizer = wx.BoxSizer(wx.VERTICAL)

        # Start-frame controls (disabled until "Range of frames" is checked).
        self.start_frames_sizer.AddSpacer(15)
        self.startFrame = wx.SpinCtrl(self.widget_panel,
                                      value="0",
                                      size=(100, -1))
        self.startFrame.Enable(False)
        self.start_frames_sizer.Add(self.startFrame, 1,
                                    wx.EXPAND | wx.ALIGN_LEFT, 15)
        start_text = wx.StaticText(self.widget_panel,
                                   label="Start Frame Index")
        self.start_frames_sizer.Add(start_text, 1, wx.EXPAND | wx.ALIGN_LEFT,
                                    15)
        self.checkBox = wx.CheckBox(self.widget_panel,
                                    id=wx.ID_ANY,
                                    label="Range of frames")
        self.checkBox.Bind(wx.EVT_CHECKBOX, self.activate_frame_range)
        self.start_frames_sizer.Add(self.checkBox, 1,
                                    wx.EXPAND | wx.ALIGN_LEFT, 15)

        # Number-of-frames controls (also disabled initially).
        self.end_frames_sizer.AddSpacer(15)
        self.endFrame = wx.SpinCtrl(self.widget_panel,
                                    value="1",
                                    size=(160, -1))
        self.endFrame.Enable(False)
        self.end_frames_sizer.Add(self.endFrame, 1, wx.EXPAND | wx.ALIGN_LEFT,
                                  15)
        end_text = wx.StaticText(self.widget_panel, label="Number of Frames")
        self.end_frames_sizer.Add(end_text, 1, wx.EXPAND | wx.ALIGN_LEFT, 15)
        self.updateFrame = wx.Button(self.widget_panel,
                                     id=wx.ID_ANY,
                                     label="Update")
        self.end_frames_sizer.Add(self.updateFrame, 1,
                                  wx.EXPAND | wx.ALIGN_LEFT, 15)
        self.updateFrame.Bind(wx.EVT_BUTTON, self.updateSlider)
        self.updateFrame.Enable(False)

        widgetsizer.Add(self.start_frames_sizer, 1, wx.ALL, 0)
        widgetsizer.AddStretchSpacer(5)
        widgetsizer.Add(self.end_frames_sizer, 1, wx.ALL, 0)
        widgetsizer.AddStretchSpacer(15)

        self.quit = wx.Button(self.widget_panel, id=wx.ID_ANY, label="Quit")
        widgetsizer.Add(self.quit, 1, wx.ALL, 15)
        self.quit.Bind(wx.EVT_BUTTON, self.quitButton)
        self.quit.Enable(True)

        # SetSizerAndFit already installs the sizer; a separate SetSizer call
        # with the same sizer is redundant.
        self.widget_panel.SetSizerAndFit(widgetsizer)

        # Variables initialization
        self.numberFrames = 0
        self.currFrame = 0
        self.figure = Figure()
        self.axes = self.figure.add_subplot(111)
        self.drs = []
        self.extract_range_frame = False
        self.firstFrame = 0
        self.Colorscheme = []

        # Read config file
        self.cfg = auxiliaryfunctions.read_config(config)
        self.Task = self.cfg["Task"]
        self.start = self.cfg["start"]
        self.stop = self.cfg["stop"]
        self.date = self.cfg["date"]
        self.trainFraction = self.cfg["TrainingFraction"][0]
        self.videos = self.cfg["video_sets"].keys()
        self.bodyparts = self.cfg["bodyparts"]
        self.colormap = plt.get_cmap(self.cfg["colormap"]).reversed()
        self.markerSize = self.cfg["dotsize"]
        self.alpha = self.cfg["alphavalue"]
        self.iterationindex = self.cfg["iteration"]
        self.cropping = self.cfg["cropping"]
        self.video_names = [Path(i).stem for i in self.videos]
        self.config_path = Path(config)
        self.video_source = Path(video).resolve()
        self.shuffle = shuffle
        self.Dataframe = Dataframe
        self.savelabeled = savelabeled
        self.multianimal = multianimal
        if self.multianimal:
            from deeplabcut.utils import auxfun_multianimal

            (
                self.individual_names,
                self.uniquebodyparts,
                self.multianimalbodyparts,
            ) = auxfun_multianimal.extractindividualsandbodyparts(self.cfg)
            self.choiceBox, self.visualization_rdb = self.choice_panel.addRadioButtons(
            )
            self.Colorscheme = visualization.get_cmap(
                len(self.individual_names), self.cfg["colormap"])
            self.visualization_rdb.Bind(wx.EVT_RADIOBOX, self.clear_plot)
        # Read the video file
        self.vid = VideoWriter(str(self.video_source))
        if self.cropping:
            self.vid.set_bbox(self.cfg["x1"], self.cfg["x2"], self.cfg["y1"],
                              self.cfg["y2"])
        self.filename = Path(self.video_source).name
        self.numberFrames = len(self.vid)
        self.strwidth = int(np.ceil(np.log10(self.numberFrames)))
        # Set the values of slider and range of frames
        self.startFrame.SetMax(self.numberFrames - 1)
        self.slider.SetMax(self.numberFrames - 1)
        self.endFrame.SetMax(self.numberFrames - 1)
        self.startFrame.Bind(wx.EVT_SPINCTRL, self.updateSlider)
        # Set the status bar
        self.statusbar.SetStatusText("Working on video: {}".format(
            self.filename))
        # Adding the video file to the config file.
        if self.vid.name not in self.video_names:
            add.add_new_videos(self.config_path, [self.video_source])

        self.update()
        self.plot_labels()
        self.widget_panel.Layout()
Beispiel #8
0
class MainFrame(BaseFrame):
    """Contains the main GUI and button boxes"""
    def __init__(self, parent, config, video, shuffle, Dataframe, savelabeled,
                 multianimal):
        """Build the manual outlier-frame extraction window and load ``video``.

        Parameters
        ----------
        parent
            Parent window, forwarded to ``BaseFrame.__init__``.
        config : str
            Path to the project configuration file (read with ``read_config``).
        video : str
            Video to browse. It is added to the project config when its name is
            not among the configured ``video_sets``.
        shuffle
            Stored as ``self.shuffle``.
        Dataframe
            Stored as ``self.Dataframe`` (presumably the predictions for
            ``video`` — confirm against callers).
        savelabeled
            Stored as ``self.savelabeled``.
        multianimal : bool
            If True, adds the individual/bodypart radio buttons and a
            per-individual colour scheme.
        """
        super(MainFrame,
              self).__init__("DeepLabCut2.0 - Manual Outlier Frame Extraction",
                             parent)

        ###################################################################################################################################################
        # Split the frame into top and bottom panels. The bottom panel contains the widgets;
        # the top panel shows the image (left) next to the choice panel (right).

        topSplitter = wx.SplitterWindow(self)
        vSplitter = wx.SplitterWindow(topSplitter)

        self.image_panel = ImagePanel(vSplitter, self.gui_size)
        self.choice_panel = ScrollPanel(vSplitter)

        # sashPosition is an integer pixel offset; convert explicitly because
        # the wx bindings may reject float arguments.
        vSplitter.SplitVertically(self.image_panel,
                                  self.choice_panel,
                                  sashPosition=int(self.gui_size[0] * 0.8))
        vSplitter.SetSashGravity(1)
        self.widget_panel = WidgetPanel(topSplitter)
        topSplitter.SplitHorizontally(vSplitter,
                                      self.widget_panel,
                                      sashPosition=int(self.gui_size[1] *
                                                       0.83))
        topSplitter.SetSashGravity(1)
        sizer = wx.BoxSizer(wx.VERTICAL)
        sizer.Add(topSplitter, 1, wx.EXPAND)
        self.SetSizer(sizer)

        ###################################################################################################################################################
        # Add Buttons to the WidgetPanel and bind them to their respective functions.

        widgetsizer = wx.WrapSizer(orient=wx.HORIZONTAL)

        self.load_button_sizer = wx.BoxSizer(wx.VERTICAL)
        self.help_button_sizer = wx.BoxSizer(wx.VERTICAL)

        self.help = wx.Button(self.widget_panel, id=wx.ID_ANY, label="Help")
        self.help_button_sizer.Add(self.help, 1, wx.ALL, 15)
        self.help.Bind(wx.EVT_BUTTON, self.helpButton)

        widgetsizer.Add(self.help_button_sizer, 1, wx.ALL, 0)

        self.grab = wx.Button(self.widget_panel,
                              id=wx.ID_ANY,
                              label="Grab Frames")
        widgetsizer.Add(self.grab, 1, wx.ALL, 15)
        self.grab.Bind(wx.EVT_BUTTON, self.grabFrame)
        self.grab.Enable(True)

        widgetsizer.AddStretchSpacer(5)
        self.slider = wx.Slider(
            self.widget_panel,
            id=wx.ID_ANY,
            value=0,
            minValue=0,
            maxValue=1,
            size=(200, -1),
            style=wx.SL_HORIZONTAL | wx.SL_AUTOTICKS | wx.SL_LABELS,
        )
        widgetsizer.Add(self.slider, 1, wx.ALL, 5)
        self.slider.Bind(wx.EVT_SLIDER, self.OnSliderScroll)

        widgetsizer.AddStretchSpacer(5)
        self.start_frames_sizer = wx.BoxSizer(wx.VERTICAL)
        self.end_frames_sizer = wx.BoxSizer(wx.VERTICAL)

        # Start-frame controls (disabled until "Range of frames" is checked).
        self.start_frames_sizer.AddSpacer(15)
        self.startFrame = wx.SpinCtrl(self.widget_panel,
                                      value="0",
                                      size=(100, -1))
        self.startFrame.Enable(False)
        self.start_frames_sizer.Add(self.startFrame, 1,
                                    wx.EXPAND | wx.ALIGN_LEFT, 15)
        start_text = wx.StaticText(self.widget_panel,
                                   label="Start Frame Index")
        self.start_frames_sizer.Add(start_text, 1, wx.EXPAND | wx.ALIGN_LEFT,
                                    15)
        self.checkBox = wx.CheckBox(self.widget_panel,
                                    id=wx.ID_ANY,
                                    label="Range of frames")
        self.checkBox.Bind(wx.EVT_CHECKBOX, self.activate_frame_range)
        self.start_frames_sizer.Add(self.checkBox, 1,
                                    wx.EXPAND | wx.ALIGN_LEFT, 15)

        # Number-of-frames controls (also disabled initially).
        self.end_frames_sizer.AddSpacer(15)
        self.endFrame = wx.SpinCtrl(self.widget_panel,
                                    value="1",
                                    size=(160, -1))
        self.endFrame.Enable(False)
        self.end_frames_sizer.Add(self.endFrame, 1, wx.EXPAND | wx.ALIGN_LEFT,
                                  15)
        end_text = wx.StaticText(self.widget_panel, label="Number of Frames")
        self.end_frames_sizer.Add(end_text, 1, wx.EXPAND | wx.ALIGN_LEFT, 15)
        self.updateFrame = wx.Button(self.widget_panel,
                                     id=wx.ID_ANY,
                                     label="Update")
        self.end_frames_sizer.Add(self.updateFrame, 1,
                                  wx.EXPAND | wx.ALIGN_LEFT, 15)
        self.updateFrame.Bind(wx.EVT_BUTTON, self.updateSlider)
        self.updateFrame.Enable(False)

        widgetsizer.Add(self.start_frames_sizer, 1, wx.ALL, 0)
        widgetsizer.AddStretchSpacer(5)
        widgetsizer.Add(self.end_frames_sizer, 1, wx.ALL, 0)
        widgetsizer.AddStretchSpacer(15)

        self.quit = wx.Button(self.widget_panel, id=wx.ID_ANY, label="Quit")
        widgetsizer.Add(self.quit, 1, wx.ALL, 15)
        self.quit.Bind(wx.EVT_BUTTON, self.quitButton)
        self.quit.Enable(True)

        # SetSizerAndFit already installs the sizer; a separate SetSizer call
        # with the same sizer is redundant.
        self.widget_panel.SetSizerAndFit(widgetsizer)

        # Variables initialization
        self.numberFrames = 0
        self.currFrame = 0
        self.figure = Figure()
        self.axes = self.figure.add_subplot(111)
        self.drs = []
        self.extract_range_frame = False
        self.firstFrame = 0
        self.Colorscheme = []

        # Read config file
        self.cfg = auxiliaryfunctions.read_config(config)
        self.Task = self.cfg["Task"]
        self.start = self.cfg["start"]
        self.stop = self.cfg["stop"]
        self.date = self.cfg["date"]
        self.trainFraction = self.cfg["TrainingFraction"][0]
        self.videos = self.cfg["video_sets"].keys()
        self.bodyparts = self.cfg["bodyparts"]
        self.colormap = plt.get_cmap(self.cfg["colormap"]).reversed()
        self.markerSize = self.cfg["dotsize"]
        self.alpha = self.cfg["alphavalue"]
        self.iterationindex = self.cfg["iteration"]
        self.cropping = self.cfg["cropping"]
        self.video_names = [Path(i).stem for i in self.videos]
        self.config_path = Path(config)
        self.video_source = Path(video).resolve()
        self.shuffle = shuffle
        self.Dataframe = Dataframe
        self.savelabeled = savelabeled
        self.multianimal = multianimal
        if self.multianimal:
            from deeplabcut.utils import auxfun_multianimal

            (
                self.individual_names,
                self.uniquebodyparts,
                self.multianimalbodyparts,
            ) = auxfun_multianimal.extractindividualsandbodyparts(self.cfg)
            self.choiceBox, self.visualization_rdb = self.choice_panel.addRadioButtons(
            )
            self.Colorscheme = visualization.get_cmap(
                len(self.individual_names), self.cfg["colormap"])
            self.visualization_rdb.Bind(wx.EVT_RADIOBOX, self.clear_plot)
        # Read the video file
        self.vid = VideoWriter(str(self.video_source))
        if self.cropping:
            self.vid.set_bbox(self.cfg["x1"], self.cfg["x2"], self.cfg["y1"],
                              self.cfg["y2"])
        self.filename = Path(self.video_source).name
        self.numberFrames = len(self.vid)
        self.strwidth = int(np.ceil(np.log10(self.numberFrames)))
        # Set the values of slider and range of frames
        self.startFrame.SetMax(self.numberFrames - 1)
        self.slider.SetMax(self.numberFrames - 1)
        self.endFrame.SetMax(self.numberFrames - 1)
        self.startFrame.Bind(wx.EVT_SPINCTRL, self.updateSlider)
        # Set the status bar
        self.statusbar.SetStatusText("Working on video: {}".format(
            self.filename))
        # Adding the video file to the config file.
        if self.vid.name not in self.video_names:
            add.add_new_videos(self.config_path, [self.video_source])

        self.update()
        self.plot_labels()
        self.widget_panel.Layout()

    def quitButton(self, event):
        """
        Ask for confirmation and close the GUI.

        Clears the status bar and shows a modal Yes/No dialog. On "Yes" a
        message is printed and this window is destroyed. The confirmation
        dialog itself is always destroyed once it has been answered.
        """
        self.statusbar.SetStatusText("")
        dlg = wx.MessageDialog(None, "Are you sure?", "Quit!",
                               wx.YES_NO | wx.ICON_WARNING)
        try:
            result = dlg.ShowModal()
        finally:
            # Modal dialogs are not destroyed automatically in wxPython;
            # release the native window explicitly.
            dlg.Destroy()
        if result == wx.ID_YES:
            print("Quitting for now!")
            self.Destroy()

    def updateSlider(self, event):
        """
        Move the slider to the start-frame value and redraw that frame.

        The start-frame control is then set back from the slider, so both show
        the value the slider accepted. The second figure axis (presumably the
        colorbar axis appended by ``plot_labels``) is deleted, and the frame and
        its labels are redrawn at the slider position.
        """
        target = self.startFrame.GetValue()
        self.slider.SetValue(target)
        self.startFrame.SetValue(self.slider.GetValue())

        self.axes.clear()
        colorbar_axis = self.figure.axes[1]
        self.figure.delaxes(colorbar_axis)
        self.grab.Bind(wx.EVT_BUTTON, self.grabFrame)

        self.currFrame = self.slider.GetValue()
        self.update()
        self.plot_labels()

    def activate_frame_range(self, event):
        """
        Toggle "range of frames" mode from the checkbox that fired ``event``.

        When the box is checked:
          * the start/end frame controls and the Update button are enabled;
          * the start frame is preset to the slider position;
          * the Grab button is disabled.
        When unchecked, the controls are disabled and Grab is re-enabled. The
        checkbox is remembered in ``self.checkSlider``.
        """
        self.checkSlider = event.GetEventObject()
        enabled = bool(self.checkSlider.GetValue())
        self.extract_range_frame = enabled

        self.startFrame.Enable(enabled)
        if enabled:
            # Start the range at the frame the user is currently viewing.
            self.startFrame.SetValue(self.slider.GetValue())
        self.endFrame.Enable(enabled)
        self.updateFrame.Enable(enabled)
        # Grab is switched off while range mode is being switched on;
        # update() enables it again.
        self.grab.Enable(not enabled)

    def line_select_callback(self, eclick, erelease):
        """Store the corners of a rectangle selection.

        ``eclick`` and ``erelease`` are the press and release events. Their
        ``xdata``/``ydata`` become ``(new_x1, new_y1)`` and
        ``(new_x2, new_y2)`` respectively.
        """
        press, release = eclick, erelease
        self.new_x1 = press.xdata
        self.new_y1 = press.ydata
        self.new_x2 = release.xdata
        self.new_y2 = release.ydata

    def OnSliderScroll(self, event):
        """
        Show the frame at the current slider position.

        Clears the image axes and deletes the second figure axis (presumably
        the colorbar axis appended by ``plot_labels``). It then syncs
        ``currFrame`` and the start-frame control with the slider and redraws
        the frame and its labels.
        """
        self.axes.clear()
        colorbar_axis = self.figure.axes[1]
        self.figure.delaxes(colorbar_axis)
        self.grab.Bind(wx.EVT_BUTTON, self.grabFrame)

        position = self.slider.GetValue()
        self.currFrame = position
        self.startFrame.SetValue(position)

        self.update()
        self.plot_labels()

    def update(self):
        """
        Render the video frame at ``self.currFrame`` in the image panel.

        Steps:
          * re-enable and rebind the Grab button;
          * refresh figure/axes/canvas from the image panel;
          * read the (optionally cropped) frame and show it with
            ``self.colormap`` under the title
            ``"<frame>/<last frame index> <file name>"``.
        Prints "Invalid frame" when the reader returns no frame.
        """
        grab = self.grab
        grab.Enable(True)
        grab.Bind(wx.EVT_BUTTON, self.grabFrame)
        self.figure, self.axes, self.canvas = self.image_panel.getfigure()

        self.vid.set_to_frame(self.currFrame)
        frame = self.vid.read_frame(crop=self.cropping)
        if frame is None:
            print("Invalid frame")
            return

        # Keep the image artist: plot_labels attaches its colorbar to it.
        self.ax = self.axes.imshow(img_as_ubyte(frame), cmap=self.colormap)
        last_index = self.numberFrames - 1
        self.axes.set_title(f"{self.currFrame}/{last_index} {self.filename}")
        self.figure.canvas.draw()

    def chooseFrame(self):
        """
        Save the frame the reader is positioned at, together with its predictions.

        Outputs go to ``<project>/labeled-data/<video stem>/``:
          * ``img<frame>.png`` holds the raw, optionally cropped frame;
          * ``img<frame>labeled.png`` holds the current figure, written only
            when ``self.savelabeled`` is set;
          * the row of ``self.Dataframe`` for ``self.currFrame`` is merged into
            ``machinelabels-iter<iteration>.h5``, plus a ``machinelabels.csv``
            copy. Rows already present in that file win over the new one.

        If the labeled-data folder does not exist, a message is printed and the
        GUI is destroyed. If the reader returns no frame, "Invalid frame" is
        printed and nothing is written.
        """
        raw = self.vid.read_frame(crop=self.cropping)
        if raw is None:
            # Same handling as update(): read_frame can return None.
            print("Invalid frame")
            return
        frame = img_as_ubyte(raw)

        fname = Path(self.filename)
        name = str(fname.stem)
        output_path = self.config_path.parents[0] / "labeled-data" / fname.stem
        output_dir = str(output_path)

        self.machinefile = os.path.join(
            output_dir,
            "machinelabels-iter" + str(self.iterationindex) + ".h5")
        DF = self.Dataframe.iloc[[self.currFrame]]
        DF.index = [
            os.path.join("labeled-data", name,
                         "img" + str(index).zfill(self.strwidth) + ".png")
            for index in DF.index
        ]
        # Same zero padding as the row keys above, so file names and keys agree.
        stem = "img" + str(self.currFrame).zfill(self.strwidth)
        img_name = os.path.join(output_dir, stem + ".png")
        labeled_img_name = os.path.join(output_dir, stem + "labeled.png")

        if not output_path.exists():
            print(
                "%s path not found. Please make sure that the video was added to the config file using the function 'deeplabcut.add_new_videos'.Quitting for now!"
                % output_path)
            self.Destroy()
            return

        io.imsave(img_name, frame)
        if self.savelabeled:
            self.figure.savefig(labeled_img_name, bbox_inches="tight")

        if Path(self.machinefile).is_file():
            # Existing rows take precedence when a frame is grabbed twice.
            data = pd.read_hdf(self.machinefile)
            combined = pd.concat([data, DF])
            combined = combined[~combined.index.duplicated(keep="first")]
        else:
            combined = DF
        combined.to_hdf(self.machinefile, key="df_with_missing", mode="w")
        combined.to_csv(os.path.join(output_dir, "machinelabels.csv"))

    def grabFrame(self, event):
        """
        Save the current frame, or a range of frames, via ``chooseFrame``.

        In range mode (``self.extract_range_frame``), the end-frame control
        holds the number of frames to extract, starting at ``self.currFrame``.
        ``self.currFrame`` is advanced to each frame in turn, so it ends on the
        last extracted frame. The range is clipped at the last frame of the
        video (``self.numberFrames - 1``).
        """
        if not self.extract_range_frame:
            self.vid.set_to_frame(self.currFrame)
            self.chooseFrame()
            return

        num_frames_extract = self.endFrame.GetValue()
        # Clip to the video length. Past the end the reader presumably yields
        # no frame, and nothing useful can be saved for it.
        stop = min(self.currFrame + num_frames_extract, self.numberFrames)
        for i in range(self.currFrame, stop):
            self.currFrame = i
            self.vid.set_to_frame(i)
            self.chooseFrame()

    def clear_plot(self, event):
        """
        Remove the drawn labels and plot them again.

        Deletes the second figure axis (presumably the colorbar axis appended
        by ``plot_labels``), removes every patch from ``self.axes`` and calls
        ``plot_labels``. It is bound to the visualization radio box, so it
        runs when the color scheme selection changes.
        """
        self.figure.delaxes(self.figure.axes[1])
        # Iterate over a snapshot: each remove() mutates axes.patches.
        for patch in list(self.axes.patches):
            patch.remove()
        self.plot_labels()

    def plot_labels(self):
        """
        Plot the analyzed labels of the current frame on top of the image.

        Appends a colorbar axis to the right of ``self.axes`` and attaches a
        colorbar to the image artist ``self.ax``. It then adds one circle patch
        per bodypart (per individual, for multi-animal projects) at the
        coordinates stored in ``self.Dataframe`` for ``self.currFrame`` and
        redraws the canvas. Prints "Invalid frame" if no frame can be read.
        """
        self.vid.set_to_frame(self.currFrame)
        frame = self.vid.read_frame()
        if frame is None:
            print("Invalid frame")
            return

        divider = make_axes_locatable(self.axes)
        cax = divider.append_axes("right", size="5%", pad=0.05)

        def add_colorbar(labels):
            # One colorbar tick per label, positioned by getColorIndices.
            self.norm, self.colorIndex = self.image_panel.getColorIndices(
                frame, labels)
            cbar = self.figure.colorbar(self.ax,
                                        cax=cax,
                                        spacing="proportional",
                                        ticks=self.colorIndex)
            cbar.set_ticklabels(labels)

        if not self.multianimal:
            add_colorbar(self.bodyparts)
            for bpindex, bp in enumerate(self.bodyparts):
                color = self.colormap(self.norm(self.colorIndex[bpindex]))
                x = self.Dataframe.xs((bp, "x"), level=(-2, -1),
                                      axis=1).values[self.currFrame]
                y = self.Dataframe.xs((bp, "y"), level=(-2, -1),
                                      axis=1).values[self.currFrame]
                self.points = [x, y, 1.0]
                self.axes.add_patch(
                    patches.Circle(
                        (x, y),
                        radius=self.markerSize,
                        fc=color,
                        alpha=self.alpha,
                    ))
            self.figure.canvas.draw()
            return

        # All bodyparts (multi-animal ones first, then unique ones),
        # de-duplicated while keeping first-occurrence order.
        combined = np.array(self.multianimalbodyparts + self.uniquebodyparts)
        _, first_idx = np.unique(combined, return_index=True)
        self.all_bodyparts = list(combined[np.sort(first_idx)])

        # Radio selection 0 colors by individual, anything else by bodypart.
        if self.visualization_rdb.GetSelection() == 0:
            labels = self.individual_names
        else:
            labels = self.all_bodyparts
        self.Colorscheme = visualization.get_cmap(len(labels),
                                                  self.cfg["colormap"])
        add_colorbar(labels)

        for ci, ind in enumerate(self.individual_names):
            # The individual named "single" holds the unique bodyparts.
            if ind == "single":
                bodyparts = self.uniquebodyparts
            else:
                bodyparts = self.multianimalbodyparts
            by_individual = self.visualization_rdb.GetSelection() == 0
            # In bodypart mode the color index restarts at 0 per individual.
            for col_idx, bp in enumerate(bodyparts):
                pts = self.Dataframe.xs(
                    (ind, bp),
                    level=("individuals", "bodyparts"),
                    axis=1,
                ).values
                fc = self.Colorscheme(ci if by_individual else col_idx)
                self.circle = patches.Circle(
                    pts[self.currFrame, :2],
                    radius=self.markerSize,
                    fc=fc,
                    alpha=self.alpha,
                )
                self.axes.add_patch(self.circle)
        self.figure.canvas.draw()

    def helpButton(self, event):
        """
        Show the usage instructions in an information message box.
        """
        # Implicit string concatenation instead of backslash continuations
        # inside one literal. Those continuations copied the source
        # indentation (stray spaces) into the displayed text.
        instructions = (
            "1. Use the slider to select a frame in the entire video. \n\n"
            "2. Click Grab Frames button to save the specific frame."
            "\n\n3. In the events where you need to extract a range of frames, "
            "then use the checkbox 'Range of frames' to select the starting "
            "frame index and the number of frames to extract."
            "\n Click the update button to see the frame. Click Grab Frames to "
            "select the range of frames. \n\n Click OK to continue"
        )
        wx.MessageBox(
            instructions,
            "Instructions to use!",
            wx.OK | wx.ICON_INFORMATION,
        )
def extract_bpt_feature_from_video(
    video,
    DLCscorer,
    trainFraction,
    cfg,
    dlc_cfg,
    sess,
    inputs,
    outputs,
    extra_dict,
    destfolder=None,
    robust_nframes=False,
):
    """Extract bodypart features for ``video`` from its saved assemblies.

    Loads ``<destfolder>/<video stem><DLCscorer>_assemblies.pickle``. It then
    opens (creating it if needed) the shelve
    ``<destfolder>/<video stem><DLCscorer>_bpt_features.pickle`` and hands
    both, together with the video reader and the model session/tensors, to
    ``GetPoseandCostsF_from_assemblies``, which fills the shelve. The shelve
    is always closed afterwards.

    Parameters
    ----------
    video : str
        Path to the video file.
    DLCscorer : str
        Scorer name; part of the input/output file names.
    trainFraction
        Not used by this function; kept for interface compatibility.
    cfg, dlc_cfg
        Project and pose configurations, forwarded to
        ``GetPoseandCostsF_from_assemblies``. ``dlc_cfg["batch_size"]``
        selects the code path.
    sess, inputs, outputs, extra_dict
        Forwarded to ``GetPoseandCostsF_from_assemblies``.
    destfolder : str, optional
        Folder holding the assemblies and receiving the features. Defaults
        to the video's folder; it is created if needed.
    robust_nframes : bool, optional
        If True, frame count and duration are computed with the reader's
        ``robust=True`` options and fps is derived from them.

    Raises
    ------
    FileNotFoundError
        If the assemblies pickle does not exist.
    NotImplementedError
        If ``int(dlc_cfg["batch_size"])`` is not greater than 1.
    """
    print("Starting to analyze", video)
    vname = Path(video).stem
    videofolder = str(Path(video).parents[0])
    if destfolder is None:
        destfolder = videofolder
    auxiliaryfunctions.attempttomakefolder(destfolder)
    # Common prefix of the analysis files. Built directly rather than by
    # splitting "<prefix>.h5" on ".h5", which truncates paths that contain
    # ".h5" elsewhere (e.g. in a folder name).
    prefix = os.path.join(destfolder, vname + DLCscorer)
    assemble_filename = prefix + "_assemblies.pickle"

    # Load the assemblies before creating the feature shelf, so a missing
    # file does not leave an empty shelf behind.
    # NOTE: pickle.load is only safe on trusted, self-generated files.
    with open(assemble_filename, "rb") as f:
        assemblies = pickle.load(f)

    print("Loading ", video)
    vid = VideoWriter(video)
    if robust_nframes:
        nframes = vid.get_n_frames(robust=True)
        duration = vid.calc_duration(robust=True)
        fps = nframes / duration
    else:
        nframes = len(vid)
        duration = vid.calc_duration(robust=False)
        fps = vid.fps

    nx, ny = vid.dimensions
    print(
        "Duration of video [s]: ",
        round(duration, 2),
        ", recorded with ",
        round(fps, 2),
        "fps!",
    )
    print(
        "Overall # of frames: ",
        nframes,
        " found with (before cropping) frame dimensions: ",
        nx,
        ny,
    )

    print("Starting to extract posture")
    batch_size = int(dlc_cfg["batch_size"])
    if batch_size <= 1:
        raise NotImplementedError(
            "Not implemented yet, please raise an GitHub issue if you need this."
        )

    # ``with`` syncs and closes the shelf even on error. Without it, the
    # features written by GetPoseandCostsF_from_assemblies may never reach
    # the disk.
    with shelve.open(
        prefix + "_bpt_features.pickle",
        protocol=pickle.DEFAULT_PROTOCOL,
    ) as feature_dict:
        # for multi animal, seems only this is used; the returned predictions
        # are not needed here.
        GetPoseandCostsF_from_assemblies(
            cfg,
            dlc_cfg,
            sess,
            inputs,
            outputs,
            vid,
            nframes,
            batch_size,
            assemblies,
            feature_dict,
            extra_dict,
        )
Beispiel #10
0
def test_writer_shorten(tmp_path, video_clip):
    """Shortening the clip to 00:00:00-00:00:02 yields a ~2 s video."""
    shortened_path = video_clip.shorten(
        "00:00:00", "00:00:02", dest_folder=str(tmp_path))
    shortened = VideoWriter(shortened_path)
    assert shortened.calc_duration() == pytest.approx(2, abs=0.1)
Beispiel #11
0
def video_clip():
    """Return a ``VideoWriter`` wrapping the module-level ``video`` path."""
    clip = VideoWriter(video)
    return clip
Beispiel #12
0
def test_writer_crop(tmp_path, video_clip):
    """Cropping to a bounding box yields a video with the box's size."""
    bbox = (0, 50, 0, 100)  # x1, x2, y1, y2
    video_clip.set_bbox(*bbox)
    cropped = VideoWriter(video_clip.crop(dest_folder=str(tmp_path)))
    x1, x2, y1, y2 = bbox
    assert cropped.dimensions == (x2 - x1, y2 - y1)
Beispiel #13
0
def stitch_tracklets(
    config_path,
    videos,
    videotype="avi",
    shuffle=1,
    trainingsetindex=0,
    n_tracks=None,
    min_length=10,
    split_tracklets=True,
    prestitch_residuals=True,
    max_gap=None,
    weight_func=None,
    destfolder=None,
    modelprefix="",
    track_method="",
    output_name="",
    transformer_checkpoint="",
):
    """
    Stitch sparse tracklets into full tracks via a graph-based,
    minimum-cost flow optimization problem.

    Parameters
    ----------
    config_path : str
        Path to the main project config.yaml file.

    videos : list
        A list of strings containing the full paths to videos for analysis or a path to the directory, where all the videos with same extension are stored.

    videotype: string, optional
        Checks for the extension of the video in case the input to the video is a directory.\n Only videos with this extension are analyzed. The default is ``.avi``

    shuffle: int, optional
        An integer specifying the shuffle index of the training dataset used for training the network. The default is 1.

    trainingsetindex: int, optional
        Integer specifying which TrainingsetFraction to use. By default the first (note that TrainingFraction is a list in config.yaml).

    n_tracks : int, optional
        Number of tracks to reconstruct. By default, taken as the number
        of individuals defined in the config.yaml. Another number can be
        passed if the number of animals in the video is different from
        the number of animals the model was trained on.

    min_length : int, optional
        Tracklets less than `min_length` frames of length
        are considered to be residuals; i.e., they do not participate
        in building the graph and finding the solution to the
        optimization problem, but are rather added last after
        "almost-complete" tracks are formed. The higher the value,
        the lesser the computational cost, but the higher the chance of
        discarding relatively long and reliable tracklets that are
        essential to solving the stitching task.
        Default is 10, and must be 3 at least.

    split_tracklets : bool, optional
        By default, tracklets whose time indices are not consecutive integers
        are split in shorter tracklets whose time continuity is guaranteed.
        This is for example very powerful to get rid of tracking errors
        (e.g., identity switches) which are often signaled by a missing
        time frame at the moment they occur. Note though that for long
        occlusions where tracker re-identification capability can be trusted,
        setting `split_tracklets` to False is preferable.

    prestitch_residuals : bool, optional
        Residuals will by default be grouped together according to their
        temporal proximity prior to being added back to the tracks.
        This is done to improve robustness and simultaneously reduce complexity.

    max_gap : int, optional
        Maximal temporal gap to allow between a pair of tracklets.
        This is automatically determined by the TrackletStitcher by default.

    weight_func : callable, optional
        Function accepting two tracklets as arguments and returning a scalar
        that must be inversely proportional to the likelihood that the tracklets
        belong to the same track; i.e., the higher the confidence that the
        tracklets should be stitched together, the lower the returned value.
        If not given and the tracklets of a video carry identities
        (``identity != -1``), edges between tracklets of the same identity
        are down-weighted for that video. Ignored when
        ``transformer_checkpoint`` is given.

    destfolder: string, optional
        Specifies the destination folder for analysis data (default is the path of the video). Note that for subsequent analysis this
        folder also needs to be passed.

    modelprefix: str, optional
        Model prefix forwarded to ``GetScorerName`` when building the scorer
        name used in the file names.

    track_method: string, optional
         Specifies the tracker used to generate the pose estimation data.
         For multiple animals, must be either 'box', 'skeleton', or 'ellipse'
         and will be taken from the config.yaml file if none is given.

    output_name : str, optional
        Name of the output h5 file.
        By default, tracks are automatically stored into the same directory
        as the pickle file and with its name.

    transformer_checkpoint : str, optional
        Path to a transformer checkpoint, loaded with ``inference.DLCTrans``.
        When given:
          * tracklet affinities come from the re-identification features in
            ``<video stem><scorer>_bpt_features.pickle``, looked up in the
            destination folder first and then in the video folder;
          * the written tracks get the suffix "transformer".

    Returns
    -------
    None
        The stitched tracks of each video are written to disk by
        ``TrackletStitcher.write_tracks``. Videos whose tracklet pickle is
        missing are skipped.
    """
    vids = deeplabcut.utils.auxiliaryfunctions.Getlistofvideos(
        videos, videotype)
    if not vids:
        print("No video(s) found. Please check your path!")
        return

    cfg = auxiliaryfunctions.read_config(config_path)
    track_method = auxfun_multianimal.get_track_method(
        cfg, track_method=track_method)

    animal_names = cfg["individuals"]
    if n_tracks is None:
        n_tracks = len(animal_names)

    DLCscorer, _ = deeplabcut.utils.auxiliaryfunctions.GetScorerName(
        cfg,
        shuffle,
        cfg["TrainingFraction"][trainingsetindex],
        modelprefix=modelprefix,
    )

    if transformer_checkpoint:
        from deeplabcut.pose_tracking_pytorch import inference

        dlctrans = inference.DLCTrans(checkpoint=transformer_checkpoint)

    def trans_weight_func(tracklet1, tracklet2, nframe, feature_dict):
        zfill_width = int(np.ceil(np.log10(nframe)))
        if tracklet1 < tracklet2:
            ind_img1 = tracklet1.inds[-1]
            coord1 = tracklet1.data[-1][:, :2]
            ind_img2 = tracklet2.inds[0]
            coord2 = tracklet2.data[0][:, :2]
        else:
            ind_img2 = tracklet2.inds[-1]
            ind_img1 = tracklet1.inds[0]
            coord2 = tracklet2.data[-1][:, :2]
            coord1 = tracklet1.data[0][:, :2]
        t1 = (coord1, ind_img1)
        t2 = (coord2, ind_img2)

        dist = dlctrans(t1, t2, zfill_width, feature_dict)
        dist = (dist + 1) / 2

        return -dist

    # Suffix of the tracklet pickle produced by the chosen tracker.
    method = {"ellipse": "el", "box": "bx"}.get(track_method, "sk")

    for video in vids:
        print("Processing... ", video)
        videofolder = str(Path(video).parents[0])
        dest = destfolder or videofolder
        deeplabcut.utils.auxiliaryfunctions.attempttomakefolder(dest)
        vname = Path(video).stem
        prefix = vname + DLCscorer

        feature_dict = None
        if transformer_checkpoint:
            import dbm

            # The frame count is only needed for the transformer weights.
            nframe = len(VideoWriter(video))
            feature_name = prefix + "_bpt_features.pickle"
            # Features are written next to the other analysis outputs (dest).
            # The video folder is still tried afterwards for files located
            # where earlier versions looked.
            for folder in dict.fromkeys((dest, videofolder)):
                try:
                    feature_dict = shelve.open(
                        os.path.join(folder, feature_name), flag="r")
                    break
                except dbm.error:
                    continue
            else:
                raise FileNotFoundError(
                    f"{os.path.join(dest, feature_name)} does not exist. "
                    "Did you run transformer_reID()?")

        # Built directly instead of splitting "<...>.h5" on ".h5", which
        # truncated paths containing ".h5" in a folder name.
        pickle_file = os.path.join(dest, f"{prefix}_{method}.pickle")
        try:
            stitcher = TrackletStitcher.from_pickle(pickle_file, n_tracks,
                                                    min_length,
                                                    split_tracklets,
                                                    prestitch_residuals)
            if transformer_checkpoint:
                video_weight_func = partial(trans_weight_func,
                                            nframe=nframe,
                                            feature_dict=feature_dict)
            else:
                # Chosen per video: the caller's weight_func is never
                # overwritten, so identity weighting picked for one video
                # cannot leak into the following ones.
                video_weight_func = weight_func
                with_id = any(tracklet.identity != -1
                              for tracklet in stitcher)
                if with_id and video_weight_func is None:

                    def identity_weight_func(t1, t2):
                        # Add in identity weighing before building the graph
                        w = 0.01 if t1.identity == t2.identity else 1
                        return w * stitcher.calculate_edge_weight(t1, t2)

                    video_weight_func = identity_weight_func

            stitcher.build_graph(max_gap=max_gap,
                                 weight_func=video_weight_func)
            stitcher.stitch()
            stitcher.write_tracks(
                output_name=output_name,
                animal_names=animal_names,
                suffix="transformer" if transformer_checkpoint else "",
            )
        except FileNotFoundError as e:
            print(e, "\nSkipping...")
        finally:
            # Release the read-only feature shelf opened for this video.
            if feature_dict is not None:
                feature_dict.close()
Beispiel #14
0
def video_clip():
    """Return a ``VideoWriter`` for ``vid.avi`` inside ``TEST_DATA_DIR``."""
    clip_path = os.path.join(TEST_DATA_DIR, "vid.avi")
    return VideoWriter(clip_path)