Esempio n. 1
0
def test_assembly():
    """Exercise the basic life cycle of ``inferenceutils.Assembly``.

    Covers link insertion, joint removal, self-intersection and the
    addition of two assemblies, as checked by the assertions below.
    """
    # A freshly created assembly reports a length of zero.
    assembly = inferenceutils.Assembly(3)
    assert len(assembly) == 0

    # Two joints at the same position (1, 1) but with different labels.
    first_joint = inferenceutils.Joint((1, 1), label=0)
    second_joint = inferenceutils.Joint((1, 1), label=1)
    link = inferenceutils.Link(first_joint, second_joint)

    # add_link is expected to return a truthy value on success; afterwards
    # both joints of the link count towards the assembly length.
    link_added = assembly.add_link(link, store_dict=True)
    assert link_added
    assert len(assembly) == 2
    # The row indexed by the joint label starts with the joint coordinate
    # and ends with -1 (presumably a default identity column -- TODO confirm).
    assert assembly.data[second_joint.label, 0] == 1
    assert assembly.data[second_joint.label, -1] == -1
    assert assembly.area == 0
    assert assembly.intersection_with(assembly) == 1.0
    # With store_dict=True the saved "data" entry is entirely NaN
    # (presumably a snapshot taken before the link was added -- TODO confirm).
    assert np.all(np.isnan(assembly._dict["data"]))

    # Removing a joint shrinks the assembly and blanks out its data row.
    assembly.remove_joint(second_joint)
    assert len(assembly) == 1
    assert np.all(np.isnan(assembly.data[second_joint.label]))

    # Adding an assembly built from the same link must raise ValueError.
    other_assembly = inferenceutils.Assembly(2)
    other_assembly.add_link(inferenceutils.Link(first_joint, second_joint))
    with pytest.raises(ValueError):
        _ = assembly + other_assembly

    # Once the first joint is dropped from the other assembly, it is no
    # longer contained in the first one and the two can be added together.
    other_assembly.remove_joint(first_joint)
    assert other_assembly not in assembly
    merged_assembly = assembly + other_assembly
    assert len(merged_assembly) == 2
def find_outliers_in_raw_data(
        config,
        pickle_file,
        video_file,
        pcutoff=0.1,
        percentiles=(5, 95),
        with_annotations=True,
        extraction_algo="kmeans",
):
    """
    Extract outlier frames from either raw detections or assemblies of multiple animals.

    The pickle file is loaded, the indices of outlier frames are computed
    according to the kind of file (raw detections or assemblies), and the
    frame extraction itself is delegated to
    ``ExtractFramesbasedonPreselection``.

    Parameters
    ----------
    config : str
        Absolute path to the project config.yaml.

    pickle_file : str or path-like
        Path to a *_full.pickle or *_assemblies.pickle. The file name must
        start with the name of the video (without extension).

    video_file : str or path-like
        Path to the corresponding video file for frame extraction.

    pcutoff : float, optional (default=0.1)
        Detection confidence threshold below which frames are flagged as
        containing outliers. Only considered if raw detections are passed in.

    percentiles : tuple, optional (default=(5, 95))
        Assemblies are considered outliers if their areas are beyond the 5th
        and 95th percentiles. Must contain a lower and upper bound.
        Only considered if assemblies are passed in.

    with_annotations : bool, optional (default=True)
        If true, extract frames and the corresponding network predictions.
        Otherwise, only the frames are extracted. Always treated as False
        when raw detections (*_full.pickle) are passed in.

    extraction_algo : string, optional (default="kmeans")
        Outlier detection algorithm. Must be either ``uniform`` or ``kmeans``.

    Raises
    ------
    ValueError
        If ``extraction_algo`` is not supported, or if the pickle file name
        does not start with the video file name.

    IOError
        If the pickle file is neither a *_full.pickle nor an
        *_assemblies.pickle.

    """
    if extraction_algo not in ("kmeans", "uniform"):
        raise ValueError(
            f"Unsupported extraction algorithm {extraction_algo}.")

    video_name = Path(video_file).stem
    pickle_name = Path(pickle_file).stem
    if not pickle_name.startswith(video_name):
        raise ValueError("Video and pickle files do not match.")

    # NOTE: unpickling can execute arbitrary code; only pass in pickle files
    # that come from a trusted source.
    with open(pickle_file, "rb") as file:
        data = pickle.load(file)

    # Normalize to a string so that path-like objects (e.g. pathlib.Path),
    # which have no ``endswith`` method, are handled like plain strings.
    pickle_path = str(pickle_file)
    if pickle_path.endswith("_full.pickle"):
        inds, data = find_outliers_in_raw_detections(data, threshold=pcutoff)
        # Raw detections are extracted without network predictions.
        with_annotations = False
    elif pickle_path.endswith("_assemblies.pickle"):
        assemblies = {}
        for key, raw_assemblies in data.items():
            # The "single" entry is not turned into assemblies.
            if key == "single":
                continue
            rebuilt = []
            for vals in raw_assemblies:
                assembly = inferenceutils.Assembly(len(vals))
                assembly.data = vals
                rebuilt.append(assembly)
            assemblies[key] = rebuilt
        inds = inferenceutils.find_outlier_assemblies(assemblies,
                                                      qs=percentiles)
    else:
        raise IOError(f"Raw data file {pickle_file} could not be parsed.")

    cfg = auxiliaryfunctions.read_config(config)
    ExtractFramesbasedonPreselection(
        inds,
        extraction_algo,
        data,
        video=video_file,
        cfg=cfg,
        config=config,
        savelabeled=False,
        with_annotations=with_annotations,
    )