Example #1
def test_frame_merge_between_predicted_and_user(skeleton, centered_pair_vid):
    user_inst = Instance(skeleton=skeleton, points={skeleton.nodes[0]: Point(1, 2)},)
    user_labels = Labels(
        [LabeledFrame(video=centered_pair_vid, frame_idx=0, instances=[user_inst],)]
    )

    pred_inst = PredictedInstance(
        skeleton=skeleton,
        points={skeleton.nodes[0]: PredictedPoint(1, 2, score=1.0)},
        score=1.0,
    )
    pred_labels = Labels(
        [LabeledFrame(video=centered_pair_vid, frame_idx=0, instances=[pred_inst],)]
    )

    # Merge predictions into current labels dataset
    _, _, new_conflicts = Labels.complex_merge_between(
        user_labels,
        new_labels=pred_labels,
        unify=False,  # both datasets already reference the same video object
    )

    # new predictions should replace old ones
    Labels.finish_complex_merge(user_labels, new_conflicts)

    # We should be able to cleanly merge the user and the predicted instance,
    # and we want to retain both even though they perfectly match.
    assert user_inst in user_labels[0].instances
    assert pred_inst in user_labels[0].instances
    assert len(user_labels[0].instances) == 2
Example #2
def fit_tracks(filename: Text, instance_count: int):
    """Wraps `TrackCleaner` for easier cli api."""

    labels = Labels.load_file(filename)
    video = labels.videos[0]
    frames = labels.find(video)

    TrackCleaner(instance_count=instance_count).run(frames=frames)

    # Rebuild list of tracks
    labels.tracks = list({
        instance.track
        for frame in labels for instance in frame.instances if instance.track
    })

    labels.tracks.sort(key=operator.attrgetter("spawned_on", "name"))

    # Save new file. Note: if the filename has none of these suffixes, the
    # original file would be overwritten.
    save_filename = filename
    save_filename = save_filename.replace(".slp", ".cleaned.slp")
    save_filename = save_filename.replace(".h5", ".cleaned.h5")
    save_filename = save_filename.replace(".json", ".cleaned.json")
    Labels.save_file(labels, save_filename)

    print(f"Saved: {save_filename}")
Example #3
def retrack():
    import argparse
    import operator
    import os
    import time

    from sleap import Labels

    parser = argparse.ArgumentParser()

    parser.add_argument("data_path", help="Path to SLEAP project file")
    parser.add_argument(
        "-o",
        "--output",
        type=str,
        default=None,
        help="The output filename to use for the predicted data.",
    )

    Tracker.add_cli_parser_args(parser)

    args = parser.parse_args()

    tracker_args = {
        key: val
        for key, val in vars(args).items() if val is not None
    }

    tracker = Tracker.make_tracker_by_name(**tracker_args)

    print(tracker)

    print("Loading predictions...")
    t0 = time.time()
    # Second arg is used as a search path for videos with broken paths.
    labels = Labels.load_file(args.data_path, args.data_path)
    frames = sorted(labels.labeled_frames, key=operator.attrgetter("frame_idx"))
    print(f"Done loading predictions in {time.time() - t0} seconds.")

    print("Starting tracker...")
    frames = run_tracker(frames=frames, tracker=tracker)
    tracker.final_pass(frames)

    new_labels = Labels(labeled_frames=frames)

    if args.output:
        output_path = args.output
    else:
        out_dir = os.path.dirname(args.data_path)
        out_name = os.path.basename(
            args.data_path) + f".{tracker.get_name()}.slp"
        output_path = os.path.join(out_dir, out_name)

    print(f"Saving: {output_path}")
    Labels.save_file(new_labels, output_path)
Example #4
def find_frame_pairs(labels_gt: Labels,
                     labels_pr: Labels,
                     user_labels_only: bool = True
                     ) -> List[Tuple[LabeledFrame, LabeledFrame]]:
    """Find corresponding frames across two sets of labels.

    Args:
        labels_gt: A `sleap.Labels` instance with ground truth instances.
        labels_pr: A `sleap.Labels` instance with predicted instances.
        user_labels_only: If False, frames with predicted instances in `labels_gt` will
            also be considered for matching.

    Returns:
        A list of pairs of `sleap.LabeledFrame`s in the form `(frame_gt, frame_pr)`.
    """
    frame_pairs = []
    for video_gt in labels_gt.videos:

        # Find matching video instance in predictions.
        video_pr = None
        for video in labels_pr.videos:
            if isinstance(video.backend, type(
                    video_gt.backend)) and video.matches(video_gt):
                video_pr = video
                break

        if video_pr is None:
            continue

        # Find labeled frames in this video.
        labeled_frames_gt = labels_gt.find(video_gt)
        if user_labels_only:
            labeled_frames_gt = [
                lf for lf in labeled_frames_gt if lf.has_user_instances
            ]

        # Attempt to match each labeled frame in the ground truth.
        for labeled_frame_gt in labeled_frames_gt:
            labeled_frames_pr = labels_pr.find(
                video_pr, frame_idx=labeled_frame_gt.frame_idx)

            if not labeled_frames_pr:
                # No match
                continue
            elif len(labeled_frames_pr) == 1:
                # Match!
                frame_pairs.append((labeled_frame_gt, labeled_frames_pr[0]))
            else:
                # Too many matches.
                raise ValueError(
                    "More than one labeled frame found in predictions.")

    return frame_pairs
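
A hedged usage sketch (paths are illustrative; the attribute accesses follow the `LabeledFrame` API used elsewhere in these examples):

labels_gt = Labels.load_file("ground_truth.slp")
labels_pr = Labels.load_file("predictions.slp")

pairs = find_frame_pairs(labels_gt, labels_pr)
print(f"Matched {len(pairs)} frame pairs.")
for frame_gt, frame_pr in pairs:
    print(frame_gt.frame_idx, len(frame_pr.predicted_instances))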
Example #5
    def merge_results(self):
        """Merges result frames into labels dataset."""
        # Remove any frames without instances
        new_lfs = [lf for lf in self.results if lf.instances]

        # Merge predictions into current labels dataset
        _, _, new_conflicts = Labels.complex_merge_between(
            self.labels,
            new_labels=Labels(new_lfs),
            unify=False,  # since we used match_to when loading predictions file
        )

        # new predictions should replace old ones
        Labels.finish_complex_merge(self.labels, new_conflicts)
Example #6
def demo_gui():
    from sleap.gui.dialogs.formbuilder import YamlFormWidget
    from sleap import Labels
    from PySide2.QtWidgets import QApplication

    labels = Labels.load_file(
        "tests/data/json_format_v2/centered_pair_predictions.json"
    )

    options_lists = dict(node=labels.skeletons[0].node_names)

    app = QApplication()
    win = YamlFormWidget.from_name(
        "suggestions", title="Generate Suggestions", field_options_lists=options_lists
    )

    def demo_suggestions(params):
        print(params)
        x = VideoFrameSuggestions.suggest(params=params, labels=labels)

        for suggested_frame in x:
            print(
                suggested_frame.video.backend.filename,
                suggested_frame.frame_idx,
                suggested_frame.group,
            )

    win.mainAction.connect(demo_suggestions)
    win.show()

    app.exec_()
Example #7
    def predict_subprocess(
        self,
        item_for_inference: ItemForInference,
        append_results: bool = False,
        waiting_callback: Optional[Callable] = None,
    ) -> Tuple[Text, bool]:
        """Runs inference in a subprocess."""
        cli_args, output_path = self.make_predict_cli_call(item_for_inference)

        print("Command line call:")
        print(" \\\n".join(cli_args))
        print()

        with sub.Popen(cli_args) as proc:
            while proc.poll() is None:
                if waiting_callback is not None:

                    if waiting_callback() == -1:
                        # -1 signals user cancellation
                        return "", False

                time.sleep(0.1)

            print(f"Process return code: {proc.returncode}")
            success = proc.returncode == 0

        if success and append_results:
            # Load frames from inference into results list
            new_inference_labels = Labels.load_file(output_path,
                                                    match_to=self.labels)
            self.results.extend(new_inference_labels.labeled_frames)

        return output_path, success
Example #8
    def read(cls, file: FileHandle, *args, **kwargs,) -> Labels:
        filename = file.filename

        # Load data from the YAML file
        project_data = yaml.load(file.text, Loader=yaml.SafeLoader)

        # Create skeleton which we'll use for each video
        skeleton = Skeleton()
        skeleton.add_nodes(project_data["bodyparts"])

        # Get subdirectories of videos and labeled data
        root_dir = os.path.dirname(filename)
        videos_dir = os.path.join(root_dir, "videos")
        labeled_data_dir = os.path.join(root_dir, "labeled-data")

        # Use a distinct loop variable so we don't shadow the `file` parameter.
        with os.scandir(labeled_data_dir) as dir_iterator:
            data_subdirs = [entry.path for entry in dir_iterator if entry.is_dir()]

        labeled_frames = []

        # Each subdirectory of labeled data corresponds to a video.
        # We'll go through each and import the labeled frames.

        for data_subdir in data_subdirs:
            csv_files = find_files_by_suffix(
                data_subdir, prefix="CollectedData", suffix=".csv"
            )

            if csv_files:
                csv_path = csv_files[0]

                # Try to find a full video corresponding to this subdir.
                # If subdirectory is foo, we look for foo.mp4 in videos dir.

                shortname = os.path.split(data_subdir)[-1]
                video_path = os.path.join(videos_dir, f"{shortname}.mp4")

                if os.path.exists(video_path):
                    video = Video.from_filename(video_path)
                else:
                    # When no video is found, the individual frame images
                    # stored in the labeled data subdir will be used.
                    print(
                        f"Unable to find {video_path} so using individual frame images."
                    )
                    video = None

                # Import the labeled frames
                labeled_frames.extend(
                    LabelsDeepLabCutCsvAdaptor.read_frames(
                        FileHandle(csv_path), full_video=video, skeleton=skeleton
                    )
                )

            else:
                print(f"No csv data file found in {data_subdir}")

        return Labels(labeled_frames=labeled_frames)
Example #9
 def read(
     cls,
     file: FileHandle,
     full_video: Optional[Video] = None,
     *args,
     **kwargs,
 ) -> Labels:
     return Labels(
         labeled_frames=cls.read_frames(file, full_video, *args, **kwargs))
Example #10
 def from_unlabeled_suggestions(cls, labels: sleap.Labels) -> "LabelsReader":
     """Create a `LabelsReader` using the unlabeled suggestions in a `Labels` set.
     Args:
         labels: A `sleap.Labels` instance containing unlabeled suggestions.
     Returns:
         A `LabelsReader` instance that can create a dataset for pipelining.
     """
     inds = labels.get_unlabeled_suggestion_inds()
     return cls(labels=labels, example_indices=inds)
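
A minimal sketch of using this constructor to build a pipeline over suggested frames (the path is illustrative, and reading `example_indices` back as an attribute is an assumption based on the constructor call above):

labels = sleap.Labels.load_file("project.slp")
reader = LabelsReader.from_unlabeled_suggestions(labels)
print(f"{len(reader.example_indices)} unlabeled suggestions to read.")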
Example #11
    def read_headers(
        cls,
        file: format.filehandle.FileHandle,
        video_search: Union[Callable, List[Text], None] = None,
        match_to: Optional[Labels] = None,
    ):
        f = file.file

        # Extract the Labels JSON metadata and create Labels object with just this
        # metadata.
        dicts = json_loads(
            f.require_group("metadata").attrs["json"].tobytes().decode()
        )

        # These items are stored in separate lists because the metadata group
        # grew too large.
        for key in ("videos", "tracks", "suggestions"):
            hdf5_key = f"{key}_json"
            if hdf5_key in f:
                items = [json_loads(item_json) for item_json in f[hdf5_key]]
                dicts[key] = items

        # Video path "." means the video is saved in same file as labels, so replace
        # these paths.
        for video_item in dicts["videos"]:
            if video_item["backend"]["filename"] == ".":
                video_item["backend"]["filename"] = file.filename

        # Use the video_callback for finding videos with broken paths:

        # 1. Accept single string as video search path
        if isinstance(video_search, str):
            video_search = [video_search]

        # 2. Accept list of strings as video search paths
        if hasattr(video_search, "__iter__"):
            # If the callback is an iterable, then we'll expect it to be a list of
            # strings and build a non-gui callback with those as the search paths.
            search_paths = list(video_search)

            # Make the search function from list of paths
            video_search = Labels.make_video_callback(search_paths)

        # 3. Use the callback function (either given as arg or build from paths)
        if callable(video_search):
            video_search(dicts["videos"])

        # Create the Labels object with the header data we've loaded
        labels = labels_json.LabelsJsonAdaptor.from_json_data(dicts, match_to=match_to)

        return labels
Example #12
    def predict_subprocess(
        self,
        item_for_inference: ItemForInference,
        append_results: bool = False,
        waiting_callback: Optional[Callable] = None,
        gui: bool = True,
    ) -> Tuple[Text, Union[Text, int]]:
        """Runs inference in a subprocess."""
        cli_args, output_path = self.make_predict_cli_call(item_for_inference, gui=gui)

        print("Command line call:")
        print(" ".join(cli_args))
        print()

        # Run inference CLI capturing output.
        with subprocess.Popen(cli_args, stdout=subprocess.PIPE) as proc:

            # Poll until finished.
            while proc.poll() is None:

                # Read line.
                line = proc.stdout.readline()
                line = line.decode().rstrip()

                if line.startswith("{"):
                    # Parse line.
                    line_data = json.loads(line)
                else:
                    # Pass through non-json output.
                    print(line)
                    line_data = {}

                if waiting_callback is not None:
                    # Pass line data to callback.
                    ret = waiting_callback(**line_data)

                    if ret == "cancel":
                        # Stop if callback returned cancel signal.
                        kill_process(proc.pid)
                        print(f"Killed PID: {proc.pid}")
                        return "", "canceled"
                time.sleep(0.05)

            print(f"Process return code: {proc.returncode}")
            success = proc.returncode == 0

        if success and append_results:
            # Load frames from inference into results list
            new_inference_labels = Labels.load_file(output_path, match_to=self.labels)
            self.results.extend(new_inference_labels.labeled_frames)

        # Return "success" or return code if failed.
        ret = "success" if success else proc.returncode
        return output_path, ret
Example #13
def save_predictions_from_cli(args, predicted_frames, prediction_metadata=None):
    from sleap import Labels

    if args.output:
        output_path = args.output
    elif args.video_path:
        out_dir = os.path.dirname(args.video_path)
        out_name = os.path.basename(args.video_path) + ".predictions.slp"
        output_path = os.path.join(out_dir, out_name)
    elif args.labels:
        out_dir = os.path.dirname(args.labels)
        out_name = os.path.basename(args.labels) + ".predictions.slp"
        output_path = os.path.join(out_dir, out_name)
    else:
        # We shouldn't ever get here but if we do, just save in working dir.
        output_path = "predictions.slp"

    labels = Labels(labeled_frames=predicted_frames, provenance=prediction_metadata)

    print(f"Saving: {output_path}")
    Labels.save_file(labels, output_path)
Example #14
    def read(
        cls,
        file: FileHandle,
        video_path: str,
        skeleton_path: str,
        *args,
        **kwargs,
    ) -> Labels:
        f = file.file

        video = Video.from_filename(video_path)
        skeleton_data = pd.read_csv(skeleton_path, header=0)

        skeleton = Skeleton()
        skeleton.add_nodes(skeleton_data["name"])
        nodes = skeleton.nodes

        for name, parent, swap in skeleton_data.itertuples(index=False,
                                                           name=None):
            if not pd.isna(parent):  # identity check against np.nan is unreliable
                skeleton.add_edge(parent, name)

        lfs = []

        pose_matrix = f["pose"][:]

        track_count, frame_count, node_count, _ = pose_matrix.shape

        tracks = [Track(0, f"Track {i}") for i in range(track_count)]
        for frame_idx in range(frame_count):
            lf_instances = []
            for track_idx in range(track_count):
                points_array = pose_matrix[track_idx, frame_idx, :, :]
                points = dict()
                for p in range(len(points_array)):
                    x, y, score = points_array[p]
                    points[nodes[p]] = Point(x, y)  # TODO: score

                inst = Instance(skeleton=skeleton,
                                track=tracks[track_idx],
                                points=points)
                lf_instances.append(inst)
            lfs.append(
                LabeledFrame(video,
                             frame_idx=frame_idx,
                             instances=lf_instances))

        return Labels(labeled_frames=lfs)
Example #15
def demo_training_dialog():
    app = QtWidgets.QApplication([])

    filename = "tests/data/json_format_v1/centered_pair.json"
    labels = Labels.load_file(filename)
    win = LearningDialog("inference", labels_filename=filename, labels=labels)

    win.frame_selection = {"clip": {labels.videos[0]: (1, 2, 3, 4)}}
    # win.training_editor_widget.set_fields_from_key_val_dict({
    #     "_backbone_name": "unet",
    #     "_heads_name": "centered_instance",
    # })
    #
    # win.training_editor_widget.form_widgets["model"].set_field_enabled("_heads_name", False)

    win.show()
    app.exec_()
Example #16
def test_inference_merging():
    skeleton = Skeleton()
    video = Video(backend=MediaVideo)
    lf_user_only = LabeledFrame(video=video,
                                frame_idx=0,
                                instances=[Instance(skeleton=skeleton)])
    lf_pred_only = LabeledFrame(
        video=video,
        frame_idx=1,
        instances=[PredictedInstance(skeleton=skeleton)])
    lf_both = LabeledFrame(
        video=video,
        frame_idx=2,
        instances=[
            Instance(skeleton=skeleton),
            PredictedInstance(skeleton=skeleton)
        ],
    )
    labels = Labels([lf_user_only, lf_pred_only, lf_both])

    task = runners.InferenceTask(
        trained_job_paths=None,
        inference_params=None,
        labels=labels,
        results=[
            LabeledFrame(
                video=labels.video,
                frame_idx=2,
                instances=[
                    PredictedInstance(skeleton=skeleton),
                    PredictedInstance(skeleton=skeleton),
                ],
            )
        ],
    )
    task.merge_results()

    assert len(labels) == 3
    assert labels[0].frame_idx == 0
    assert labels[0].has_user_instances
    assert labels[1].frame_idx == 1
    assert labels[1].has_predicted_instances
    assert labels[2].frame_idx == 2
    assert len(labels[2].user_instances) == 1
    assert len(labels[2].predicted_instances) == 2
Example #17
    def skeleton(self):
        # cache skeleton so we only search once
        if self._skeleton is None and not self._tried_finding_skeleton:

            # if skeleton was saved in config, great!
            if self.config.data.labels.skeletons:
                self._skeleton = self.config.data.labels.skeletons[0]

            # otherwise try loading it from validation labels (much slower!)
            else:
                filename = self._get_file_path("labels_gt.val.slp")
                if filename is not None:
                    val_labels = Labels.load_file(filename)
                    if val_labels.skeletons:
                        self._skeleton = val_labels.skeletons[0]

            # don't try loading again (needed in case it's still None)
            self._tried_finding_skeleton = True

        return self._skeleton
Example #18
    def __init__(
        self,
        mode: Text,
        labels_filename: Text,
        labels: Optional[Labels] = None,
        skeleton: Optional["Skeleton"] = None,
        *args,
        **kwargs,
    ):
        super().__init__()

        if labels is None:
            labels = Labels.load_file(labels_filename)

        if skeleton is None and labels.skeletons:
            skeleton = labels.skeletons[0]

        self.mode = mode
        self.labels_filename = labels_filename
        self.labels = labels
        self.skeleton = skeleton

        self._frame_selection = None

        self.current_pipeline = ""

        self.tabs = dict()
        self.shown_tab_names = []

        self._cfg_getter = configs.TrainingConfigsGetter.make_from_labels_filename(
            labels_filename=self.labels_filename)

        # Layout for buttons
        buttons = QtWidgets.QDialogButtonBox()
        self.cancel_button = buttons.addButton(
            QtWidgets.QDialogButtonBox.Cancel)
        self.save_button = buttons.addButton(
            "Save configuration files...",
            QtWidgets.QDialogButtonBox.ApplyRole)
        self.run_button = buttons.addButton(
            "Run", QtWidgets.QDialogButtonBox.AcceptRole)

        buttons_layout = QtWidgets.QHBoxLayout()
        buttons_layout.addWidget(buttons, alignment=QtCore.Qt.AlignTop)

        buttons_layout_widget = QtWidgets.QWidget()
        buttons_layout_widget.setLayout(buttons_layout)

        self.pipeline_form_widget = TrainingPipelineWidget(mode=mode,
                                                           skeleton=skeleton)
        if mode == "training":
            tab_label = "Training Pipeline"
        elif mode == "inference":
            # self.pipeline_form_widget = InferencePipelineWidget()
            tab_label = "Inference Pipeline"
        else:
            raise ValueError(f"Invalid LearningDialog mode: {mode}")

        self.tab_widget = QtWidgets.QTabWidget()

        self.tab_widget.addTab(self.pipeline_form_widget, tab_label)
        self.make_tabs()

        self.message_widget = QtWidgets.QLabel("")

        # Layout for entire dialog
        layout = QtWidgets.QVBoxLayout()
        layout.addWidget(self.tab_widget)
        layout.addWidget(self.message_widget)
        layout.addWidget(buttons_layout_widget)

        self.setLayout(layout)

        # Default to most recently trained pipeline (if there is one)
        self.set_pipeline_from_most_recent()

        # Connect functions to update pipeline tabs when pipeline changes
        self.pipeline_form_widget.updatePipeline.connect(self.set_pipeline)
        self.pipeline_form_widget.emitPipeline()

        self.connect_signals()

        # Connect actions for buttons
        buttons.accepted.connect(self.run)
        buttons.rejected.connect(self.reject)
        buttons.clicked.connect(self.on_button_click)

        # Connect button for previewing the training data
        if "_view_datagen" in self.pipeline_form_widget.buttons:
            self.pipeline_form_widget.buttons["_view_datagen"].clicked.connect(
                self.view_datagen)
Example #19
    def read(
        cls,
        file: FileHandle,
        video: Union[Video, str],
        *args,
        **kwargs,
    ) -> Labels:
        connect_adj_nodes = False

        if video is None:
            raise ValueError(
                "Cannot read analysis hdf5 if no video specified.")

        if not isinstance(video, Video):
            video = Video.from_filename(video)

        f = file.file
        tracks_matrix = f["tracks"][:].T
        track_names_list = f["track_names"][:].T
        node_names_list = f["node_names"][:].T

        # shape: frames * nodes * 2 * tracks
        frame_count, node_count, _, track_count = tracks_matrix.shape

        tracks = [
            Track(0, track_name.decode()) for track_name in track_names_list
        ]

        skeleton = Skeleton()
        last_node_name = None
        for node_name in node_names_list:
            node_name = node_name.decode()
            skeleton.add_node(node_name)
            if connect_adj_nodes and last_node_name:
                skeleton.add_edge(last_node_name, node_name)
            last_node_name = node_name

        frames = []
        for frame_idx in range(frame_count):
            instances = []
            for track_idx in range(track_count):
                points = tracks_matrix[frame_idx, ..., track_idx]
                if not np.all(np.isnan(points)):
                    point_scores = np.ones(len(points))
                    # make everything a PredictedInstance since the usual use
                    # case is to export predictions for analysis
                    instances.append(
                        PredictedInstance.from_arrays(
                            points=points,
                            point_confidences=point_scores,
                            skeleton=skeleton,
                            track=tracks[track_idx],
                            instance_score=1,
                        ))
            if instances:
                frames.append(
                    LabeledFrame(video=video,
                                 frame_idx=frame_idx,
                                 instances=instances))

        return Labels(labeled_frames=frames)
Example #20
            self.context.labels.remove_instance(lf, inst, in_transaction=True)
            if not lf.instances:
                self.context.labels.remove(lf)

        # Update caches since we skipped doing this after each deletion
        self.context.labels.update_cache()

        # Log update
        self.context.changestack_push("delete instances")


if __name__ == "__main__":

    app = QtWidgets.QApplication([])

    from sleap import Labels
    from sleap.gui.commands import CommandContext

    labels = Labels.load_file(
        "tests/data/json_format_v2/centered_pair_predictions.json")
    context = CommandContext.from_labels(labels)
    context.state["frame_idx"] = 123
    context.state["video"] = labels.videos[0]
    context.state["has_frame_range"] = True
    context.state["frame_range"] = (10, 20)

    win = DeleteDialog(context=context)
    win.show()

    app.exec_()
Example #21
    def read(
        cls,
        file: FileHandle,
        video_search: Union[Callable, List[Text], None] = None,
        match_to: Optional[Labels] = None,
        *args,
        **kwargs,
    ) -> Labels:
        """
        Deserialize JSON file as new :class:`Labels` instance.

        Args:
            file: The :class:`FileHandle` for the JSON file to deserialize.
            video_search: A callback function which can modify video paths
                before we try to create the corresponding :class:`Video`
                objects. Usually you'll want to pass a callback created by
                :meth:`make_video_callback` or :meth:`make_gui_video_callback`.
                Alternately, if you pass a list of strings we'll construct a
                non-gui callback with those strings as the search paths.
            match_to: If given, we'll replace particular objects in the
                data dictionary with *matching* objects in the match_to
                :class:`Labels` object. This ensures that the newly
                instantiated :class:`Labels` can be merged without
                duplicate matching objects (e.g., :class:`Video` objects).
        Returns:
            A new :class:`Labels` object.
        """

        tmp_dir = None
        filename = file.filename

        # Check if the file is a zipfile or not.
        if zipfile.is_zipfile(filename):

            # Make a tmpdir, located in the directory that the file exists, to unzip
            # its contents.
            tmp_dir = os.path.join(
                os.path.dirname(filename),
                f"tmp_{os.getpid()}_{os.path.basename(filename)}",
            )
            if os.path.exists(tmp_dir):
                shutil.rmtree(tmp_dir, ignore_errors=True)
            try:
                os.mkdir(tmp_dir)
            except FileExistsError:
                pass

            # tmp_dir = tempfile.mkdtemp(dir=os.path.dirname(filename))

            try:

                # Register a cleanup routine that deletes the tmpdir on program exit
                # if something goes wrong. The True is for ignore_errors
                atexit.register(shutil.rmtree, tmp_dir, True)

                # Uncompress the data into the directory
                shutil.unpack_archive(filename, extract_dir=tmp_dir)

                # We can now open the JSON file, save the zip file and
                # replace file with the first JSON file we find in the archive.
                json_files = [
                    os.path.join(tmp_dir, file) for file in os.listdir(tmp_dir)
                    if file.endswith(".json")
                ]

                if len(json_files) == 0:
                    raise ValueError(
                        f"No JSON file found inside {filename}. Are you sure "
                        "this is a valid SLEAP dataset?"
                    )

                filename = json_files[0]

            except Exception:
                # If we had problems, delete the temp directory and reraise the exception.
                shutil.rmtree(tmp_dir, ignore_errors=True)
                raise

        # Open and parse the JSON in filename
        with open(filename, "r") as file:

            # FIXME: Peek into the JSON to see if there is a version string.
            # We do this to tell apart old JSON data from leap_dev vs. the
            # newer format for SLEAP.
            json_str = file.read()
            dicts = json_loads(json_str)

            # If we have a version number, then it is the new SLEAP format
            if "version" in dicts:

                # Cache the working directory.
                cwd = os.getcwd()
                # Replace local video paths (for imagestore)
                if tmp_dir:
                    for vid in dicts["videos"]:
                        vid["backend"]["filename"] = os.path.join(
                            tmp_dir, vid["backend"]["filename"])

                # Use the video_callback for finding videos with broken paths:

                # 1. Accept single string as video search path
                if isinstance(video_search, str):
                    video_search = [video_search]

                # 2. Accept list of strings as video search paths
                if hasattr(video_search, "__iter__"):
                    # If the callback is an iterable, then we'll expect it to be a
                    # list of strings and build a non-gui callback with those as
                    # the search paths.
                    # When path is to a file, use the path of parent directory.
                    search_paths = [
                        os.path.dirname(path) if os.path.isfile(path) else path
                        for path in video_search
                    ]

                    # Make the search function from list of paths
                    video_search = Labels.make_video_callback(search_paths)

                # 3. Use the callback function (either given as arg or build from paths)
                if callable(video_search):
                    abort = video_search(dicts["videos"])
                    if abort:
                        raise FileNotFoundError

                # Try to load the labels filename.
                try:
                    labels = cls.from_json_data(dicts, match_to=match_to)

                except FileNotFoundError:

                    # FIXME: We are loading a labels JSON that has references to
                    # video files. Let's change directory to the dirname of the
                    # JSON file so that relative paths will be resolved from this
                    # directory. Maybe it would be better to feed the dataset
                    # dirname all the way down to the Video object, but this
                    # approach means less coupling between classes.
                    if os.path.dirname(filename) != "":
                        os.chdir(os.path.dirname(filename))

                    # Try again
                    labels = cls.from_json_data(dicts, match_to=match_to)

                except Exception:
                    # Ok, we give up, where the hell are these videos!
                    raise  # Re-raise.
                finally:
                    # Make sure to change back even if we had problems.
                    os.chdir(cwd)

                return labels

            else:
                frames = load_labels_json_old(data_path=filename,
                                              parsed_json=dicts)
                return Labels(frames)
Example #22
    def from_json_data(cls,
                       data: Union[str, dict],
                       match_to: Optional["Labels"] = None) -> "Labels":
        """
        Create instance of class from data in dictionary.

        Method is used by other methods that load from JSON.

        Args:
            data: Dictionary, deserialized from JSON.
            match_to: If given, we'll replace particular objects in the
                data dictionary with *matching* objects in the match_to
                :class:`Labels` object. This ensures that the newly
                instantiated :class:`Labels` can be merged without
                duplicate matching objects (e.g., :class:`Video` objects).
        Returns:
            A new :class:`Labels` object.
        """

        # Parse the JSON string if needed.
        if isinstance(data, str):
            dicts = json_loads(data)
        else:
            dicts = data

        # Don't break if the JSON doesn't include tracks.
        dicts["tracks"] = dicts.get("tracks", [])

        # First, deserialize the skeletons, videos, and nodes lists.
        # The labels reference these so we will need them while deserializing.
        nodes = cattr.structure(dicts["nodes"], List[Node])

        idx_to_node = {i: nodes[i] for i in range(len(nodes))}
        skeletons = Skeleton.make_cattr(idx_to_node).structure(
            dicts["skeletons"], List[Skeleton])
        videos = Video.cattr().structure(dicts["videos"], List[Video])

        try:
            # First try structuring from tuples (newer format).
            track_cattr = cattr.Converter(
                unstruct_strat=cattr.UnstructureStrategy.AS_TUPLE)
            tracks = track_cattr.structure(dicts["tracks"], List[Track])
        except Exception:
            # Then try structuring from dicts (older format).
            try:
                tracks = cattr.structure(dicts["tracks"], List[Track])
            except Exception:
                raise ValueError("Unable to load tracks as tuple or dict!")

        # if we're given a Labels object to match, use its objects when they match
        if match_to is not None:
            for idx, sk in enumerate(skeletons):
                for old_sk in match_to.skeletons:
                    if sk.matches(old_sk):
                        # use nodes from matched skeleton
                        for (node, match_node) in zip(sk.nodes, old_sk.nodes):
                            node_idx = nodes.index(node)
                            nodes[node_idx] = match_node
                        # use skeleton from match
                        skeletons[idx] = old_sk
                        break
            for idx, vid in enumerate(videos):
                for old_vid in match_to.videos:

                    # Try to match videos using either their current or source filename
                    # if available.
                    old_vid_paths = [old_vid.filename]
                    if getattr(old_vid.backend, "has_embedded_images", False):
                        old_vid_paths.append(
                            old_vid.backend._source_video.filename)

                    new_vid_paths = [vid.filename]
                    if getattr(vid.backend, "has_embedded_images", False):
                        new_vid_paths.append(
                            vid.backend._source_video.filename)

                    is_match = False
                    for old_vid_path in old_vid_paths:
                        for new_vid_path in new_vid_paths:
                            if old_vid_path == new_vid_path or weak_filename_match(
                                    old_vid_path, new_vid_path):
                                is_match = True
                                videos[idx] = old_vid
                                break
                        if is_match:
                            break
                    if is_match:
                        break

        suggestions = []
        if "suggestions" in dicts:
            suggestions_cattr = cattr.Converter()
            suggestions_cattr.register_structure_hook(
                Video, lambda x, type: videos[int(x)])
            try:
                suggestions = suggestions_cattr.structure(
                    dicts["suggestions"], List[SuggestionFrame])
            except Exception as e:
                print("Error while loading suggestions (1)")
                print(e)

                try:
                    # Convert old suggestion format to new format.
                    # Old format: {video: list of frame indices}
                    # New format: [SuggestionFrames]
                    old_suggestions = suggestions_cattr.structure(
                        dicts["suggestions"], Dict[Video, List])
                    for video in old_suggestions.keys():
                        suggestions.extend([
                            SuggestionFrame(video, idx)
                            for idx in old_suggestions[video]
                        ])
                except Exception as e:
                    print("Error while loading suggestions (2)")
                    print(e)

        if "negative_anchors" in dicts:
            negative_anchors_cattr = cattr.Converter()
            negative_anchors_cattr.register_structure_hook(
                Video, lambda x, type: videos[int(x)])
            negative_anchors = negative_anchors_cattr.structure(
                dicts["negative_anchors"], Dict[Video, List])
        else:
            negative_anchors = dict()

        if "provenance" in dicts:
            provenance = dicts["provenance"]
        else:
            provenance = dict()

        # If there is actual labels data, get it.
        if "labels" in dicts:
            label_cattr = make_instance_cattr()
            label_cattr.register_structure_hook(
                Skeleton, lambda x, type: skeletons[int(x)])
            label_cattr.register_structure_hook(Video,
                                                lambda x, type: videos[int(x)])
            label_cattr.register_structure_hook(
                Node, lambda x, type: x
                if isinstance(x, Node) else nodes[int(x)])
            label_cattr.register_structure_hook(
                Track, lambda x, type: None if x is None else tracks[int(x)])

            labels = label_cattr.structure(dicts["labels"], List[LabeledFrame])
        else:
            labels = []

        return Labels(
            labeled_frames=labels,
            videos=videos,
            skeletons=skeletons,
            nodes=nodes,
            suggestions=suggestions,
            negative_anchors=negative_anchors,
            tracks=tracks,
            provenance=provenance,
        )
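
A hedged sketch of reaching this classmethod directly (normally it is invoked through the adaptor's `read`, as in Example #21; `current_labels` stands in for an already-loaded `Labels` to deduplicate against):

with open("labels.json", "r") as f:
    dicts = json_loads(f.read())
new_labels = LabelsJsonAdaptor.from_json_data(dicts, match_to=current_labels)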
Example #23
def evaluate_model(
    cfg: TrainingJobConfig,
    labels_reader: LabelsReader,
    model: Model,
    save: bool = True,
    split_name: Text = "test",
) -> Tuple[Labels, Dict[Text, Any]]:
    """Evaluate a trained model and save metrics and predictions.

    Args:
        cfg: The `TrainingJobConfig` associated with the model.
        labels_reader: A `LabelsReader` pipeline generator that reads the ground truth
            data to evaluate.
        model: The `sleap.nn.model.Model` instance to evaluate.
        save: If True, save the predictions and metrics to the model folder.
        split_name: String name to append to the saved filenames.

    Returns:
        A tuple of `(labels_pr, metrics)`.

        `labels_pr` will contain the predicted labels.

        `metrics` will contain the evaluated metrics given the predictions, or None if
        the metrics failed to be computed.
    """
    # Setup predictor for evaluation.
    head_config = cfg.model.heads.which_oneof()
    if isinstance(head_config, CentroidsHeadConfig):
        predictor = TopDownPredictor(
            centroid_config=cfg,
            centroid_model=model,
            confmap_config=None,
            confmap_model=None,
        )
    elif isinstance(head_config, CenteredInstanceConfmapsHeadConfig):
        predictor = TopDownPredictor(
            centroid_config=None,
            centroid_model=None,
            confmap_config=cfg,
            confmap_model=model,
        )
    elif isinstance(head_config, MultiInstanceConfig):
        predictor = sleap.nn.inference.BottomUpPredictor(bottomup_config=cfg,
                                                         bottomup_model=model)
    elif isinstance(head_config, SingleInstanceConfmapsHeadConfig):
        predictor = sleap.nn.inference.SingleInstancePredictor(
            confmap_config=cfg, confmap_model=model)
    else:
        raise ValueError(f"Unrecognized model type: {head_config}")

    # Predict.
    labels_pr = predictor.predict(labels_reader, make_labels=True)

    # Compute metrics.
    try:
        metrics = evaluate(labels_reader.labels, labels_pr)
    except Exception:
        logger.warning("Failed to compute metrics.")
        metrics = None

    # Save.
    if save:
        labels_pr_path = os.path.join(cfg.outputs.run_path,
                                      f"labels_pr.{split_name}.slp")
        Labels.save_file(labels_pr, labels_pr_path)
        logger.info("Saved predictions: %s", labels_pr_path)

        if metrics is not None:
            metrics_path = os.path.join(cfg.outputs.run_path,
                                        f"metrics.{split_name}.npz")
            np.savez_compressed(metrics_path, **{"metrics": metrics})
            logger.info("Saved metrics: %s", metrics_path)

    if metrics is not None:
        logger.info("OKS mAP: %f", metrics["oks_voc.mAP"])

    return labels_pr, metrics
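
A hedged sketch of invoking `evaluate_model` after training (the config, model, and path are illustrative assumptions, and `LabelsReader(labels=...)` follows the constructor pattern shown in Example #10):

labels_reader = LabelsReader(labels=Labels.load_file("labels_gt.test.slp"))
labels_pr, metrics = evaluate_model(
    cfg=cfg,  # the TrainingJobConfig used to train `model`
    labels_reader=labels_reader,
    model=model,  # the trained sleap.nn.model.Model
    save=True,
    split_name="test",
)
if metrics is not None:
    print("OKS mAP:", metrics["oks_voc.mAP"])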
Example #24
# This script automatically deletes predictions with scores below 10%.
# It also names each track after the uploaded file, but may need to be altered
# based on circumstances (lines 36-37).

from sleap import Labels
import sys

SCORE_THRESHOLD = 0.1

filename = sys.argv[1]
out_filename = str(filename) + "_highscores_trackName.h5"

labels = Labels.load_file(filename)

lf_inst_list = []

# Find the (frame, instance) pairs with score below threshold
for frame in labels:
    for instance in frame:
        score = getattr(instance, "score", None)
        if score is not None and score < SCORE_THRESHOLD:
            lf_inst_list.append((frame, instance))

if lf_inst_list:

    print(f"Removing {len(lf_inst_list)} instances...")

    # Remove each of the instances
    for frame, instance in lf_inst_list:
        labels.remove_instance(frame, instance, in_transaction=True)
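
    # Update caches since we skipped doing this after each deletion
    labels.update_cache()

# The rest of the script (the track renaming on lines 36-37 and the final save)
# is truncated in this excerpt; a plausible completion, under the assumptions
# stated in the header comment:
for track in labels.tracks:
    track.name = str(filename)  # name each track after the uploaded file

Labels.save_file(labels, out_filename)
print(f"Saved: {out_filename}")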
Example #25
    def read(
        cls,
        file: FileHandle,
        img_dir: str,
        use_missing_gui: bool = False,
        *args,
        **kwargs,
    ) -> Labels:

        dicts = file.json

        # Make skeletons from "categories"
        skeleton_map = dict()
        for category in dicts["categories"]:
            skeleton = Skeleton(name=category["name"])
            skeleton_id = category["id"]
            node_names = category["keypoints"]
            skeleton.add_nodes(node_names)

            try:
                for src_idx, dst_idx in category["skeleton"]:
                    skeleton.add_edge(node_names[src_idx], node_names[dst_idx])
            except IndexError:
                # According to the COCO data format specification[^1], the edges
                # are supposed to be 1-indexed. But in some of their own
                # datasets the edges are 0-indexed! So we try 0-indexing first
                # and fall back to 1-indexing when we hit an IndexError.
                # [^1]: http://cocodataset.org/#format-data

                # Clear any edges we already created using 0-indexing
                skeleton.clear_edges()

                # Add edges
                for src_idx, dst_idx in category["skeleton"]:
                    skeleton.add_edge(node_names[src_idx - 1], node_names[dst_idx - 1])

            skeleton_map[skeleton_id] = skeleton

        # Make videos from "images"

        # Remove images that aren't referenced in the annotations
        img_refs = [annotation["image_id"] for annotation in dicts["annotations"]]
        dicts["images"] = list(filter(lambda im: im["id"] in img_refs, dicts["images"]))

        # Key in JSON file should be "file_name", but sometimes it's "filename",
        # so we have to check both.
        img_filename_key = "file_name"
        if img_filename_key not in dicts["images"][0].keys():
            img_filename_key = "filename"

        # First add the img_dir to each image filename
        img_paths = [
            os.path.join(img_dir, image[img_filename_key]) for image in dicts["images"]
        ]

        # See if there are any missing files
        img_missing = [not os.path.exists(path) for path in img_paths]

        if sum(img_missing):
            if use_missing_gui:
                okay = MissingFilesDialog(img_paths, img_missing).exec_()

                if not okay:
                    return None
            else:
                raise FileNotFoundError(
                    f"Images for COCO dataset could not be found in {img_dir}."
                )

        # Update the image paths (with img_dir or user selected path)
        for image, path in zip(dicts["images"], img_paths):
            image[img_filename_key] = path

        # Create the video objects for the image files
        image_video_map = dict()

        vid_id_video_map = dict()
        for image in dicts["images"]:
            image_id = image["id"]
            image_filename = image[img_filename_key]

            # Sometimes images have a vid_id which links multiple images
            # together as one video. If so, we'll use that as the video key.
            # But if there isn't a vid_id, we'll treat each image as a
            # distinct video and use the image id as the video id.
            vid_id = image.get("vid_id", image_id)

            if vid_id not in vid_id_video_map:
                kwargs = dict(filenames=[image_filename])
                for key in ("width", "height"):
                    if key in image:
                        kwargs[key] = image[key]

                video = Video.from_image_filenames(**kwargs)
                vid_id_video_map[vid_id] = video
                frame_idx = 0
            else:
                video = vid_id_video_map[vid_id]
                frame_idx = video.num_frames
                video.backend.filenames.append(image_filename)

            image_video_map[image_id] = (video, frame_idx)

        # Make instances from "annotations"
        lf_map = dict()
        track_map = dict()
        for annotation in dicts["annotations"]:
            skeleton = skeleton_map[annotation["category_id"]]
            image_id = annotation["image_id"]
            video, frame_idx = image_video_map[image_id]
            keypoints = np.array(annotation["keypoints"], dtype="int").reshape(-1, 3)

            track = None
            if "track_id" in annotation:
                track_id = annotation["track_id"]
                if track_id not in track_map:
                    track_map[track_id] = Track(frame_idx, str(track_id))
                track = track_map[track_id]

            points = dict()
            any_visible = False
            for i in range(len(keypoints)):
                node = skeleton.nodes[i]
                x, y, flag = keypoints[i]

                if flag == 0:
                    # node not labeled for this instance
                    continue

                is_visible = flag == 2
                any_visible = any_visible or is_visible
                points[node] = Point(x, y, is_visible)

            if points:
                # If none of the points had 2 as the "visible" flag, we'll
                # assume this is incorrect and just mark all as visible.
                if not any_visible:
                    for point in points.values():
                        point.visible = True

                inst = Instance(skeleton=skeleton, points=points, track=track)

                if image_id not in lf_map:
                    lf_map[image_id] = LabeledFrame(video, frame_idx)

                lf_map[image_id].insert(0, inst)

        return Labels(labeled_frames=list(lf_map.values()))
Example #26
    def read(
        cls, file: FileHandle, gui: bool = True, *args, **kwargs,
    ):
        filename = file.filename

        mat_contents = sio.loadmat(filename)

        box_path = cls._unwrap_mat_scalar(mat_contents["boxPath"])

        # If the video file isn't found, try in the same dir as the mat file
        if not os.path.exists(box_path):
            file_dir = os.path.dirname(filename)
            box_path_name = box_path.split("\\")[-1]  # assume windows path
            box_path = os.path.join(file_dir, box_path_name)

        if not os.path.exists(box_path):
            if gui:
                video_paths = [box_path]
                missing = [True]
                okay = MissingFilesDialog(video_paths, missing).exec_()

                if not okay or missing[0]:
                    return

                box_path = video_paths[0]
            else:
                # Ignore missing videos if not loading from gui
                box_path = ""

        if os.path.exists(box_path):
            vid = Video.from_hdf5(
                dataset="box", filename=box_path, input_format="channels_first"
            )
        else:
            vid = None

        nodes_ = mat_contents["skeleton"]["nodes"]
        edges_ = mat_contents["skeleton"]["edges"]
        points_ = mat_contents["positions"]

        edges_ = edges_ - 1  # convert matlab 1-indexing to python 0-indexing

        nodes = cls._unwrap_mat_array(nodes_)
        edges = cls._unwrap_mat_array(edges_)

        nodes = list(map(str, nodes))  # convert np.str_ to str

        sk = Skeleton(name=filename)
        sk.add_nodes(nodes)
        for edge in edges:
            sk.add_edge(source=nodes[edge[0]], destination=nodes[edge[1]])

        labeled_frames = []
        node_count, _, frame_count = points_.shape

        for i in range(frame_count):
            new_inst = Instance(skeleton=sk)
            for node_idx, node in enumerate(nodes):
                x = points_[node_idx][0][i]
                y = points_[node_idx][1][i]
                new_inst[node] = Point(x, y)
            if len(new_inst.points):
                new_frame = LabeledFrame(video=vid, frame_idx=i)
                new_frame.instances = (new_inst,)
                labeled_frames.append(new_frame)

        labels = Labels(labeled_frames=labeled_frames, videos=[vid], skeletons=[sk])

        return labels
Example #27
    """Returns mean of aligned points for instances."""
    points = get_instances_points(instances)

    node_a, node_b = get_most_stable_node_pair(points, min_dist=4.0)

    aligned = align_instances(points, node_a=node_a, node_b=node_b)
    points_mean, points_std = get_mean_and_std_for_points(aligned)
    return points_mean


if __name__ == "__main__":
    # filename = "tests/data/json_format_v2/centered_pair_predictions.json"
    # filename = "/Volumes/fileset-mmurthy/shruthi/code/sleap_expts/preds/screen_all.5pts_tmp_augment_200122/191210_102108_18159112_rig3_2.preds.h5"
    filename = "/Volumes/fileset-mmurthy/talmo/wt_gold_labeling/100919.sleap_wt_gold.13pt_init.n=288.junyu.h5"

    labels = Labels.load_file(filename)

    points = get_instances_points(labels.instances())
    get_stable_node_pairs(points, np.array(labels.skeletons[0].node_names))

    # import time
    #
    # t0 = time.time()
    # labels.add_instance(
    #     frame=labels.find_first(video=labels.videos[0]),
    #     instance=make_mean_instance(align_instances(points, 12, 0))
    # )
    # print(labels.find_first(video=labels.videos[0]))
    # print("time", time.time() - t0)
    #
    # Labels.save_file(labels, "mean.h5")
Example #28
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("input_path", help="Path to input file.")
    parser.add_argument(
        "-o", "--output", default="", help="Path to output file (optional)."
    )
    parser.add_argument(
        "--format",
        default="slp",
        help="Output format. Default ('slp') is SLEAP dataset; "
        "'analysis' results in analysis.h5 file; "
        "'h5' or 'json' results in SLEAP dataset "
        "with specified file format.",
    )
    parser.add_argument(
        "--video", default="", help="Path to video (if needed for conversion)."
    )

    args = parser.parse_args()

    video_callback = Labels.make_video_callback([os.path.dirname(args.input_path)])
    try:
        labels = Labels.load_file(args.input_path, video_search=video_callback)
    except TypeError:
        print("Input file isn't SLEAP dataset so attempting other importers...")
        from sleap.io.format import read

        video_path = args.video if args.video else None

        labels = read(
            args.input_path,
            for_object="labels",
            as_format="*",
            video_search=video_callback,
            video=video_path,
        )

    if args.format == "analysis":
        from sleap.info.write_tracking_h5 import main as write_analysis

        if args.output:
            output_path = args.output
        else:
            output_path = args.input_path
            output_path = re.sub(r"(\.json(\.zip)?|\.h5|\.slp)$", "", output_path)
            output_path = output_path + ".analysis.h5"

        write_analysis(labels, output_path=output_path, all_frames=True)

    elif args.output:
        print(f"Output SLEAP dataset: {args.output}")
        Labels.save_file(labels, args.output)

    elif args.format in ("slp", "h5", "json"):
        output_path = f"{args.input_path}.{args.format}"
        print(f"Output SLEAP dataset: {output_path}")
        Labels.save_file(labels, output_path)

    else:
        print("You didn't specify how to convert the file.")
        print(args)