예제 #1
0
def test_stitcher_with_identity(real_tracklets):
    """Exercise TrackletStitcher on tracklets that carry identity labels.

    The test stamps identities 0..2 onto the fixture data, splits the
    resulting tracklets, rebuilds the graph with an identity-aware weight
    function, and checks that stitching yields three tracks whose
    identities are 0, 1 and 2.
    """
    # Write a fake identity into the last column of each of the first
    # three tracklets' arrays.
    for fake_id in range(3):
        for data in real_tracklets[fake_id].values():
            data[:, -1] = fake_id
    stitcher = TrackletStitcher.from_dict_of_dict(real_tracklets, n_tracks=3)
    tracklets = sorted(stitcher, key=lambda t: t.identity)
    for expected_id, tracklet in enumerate(tracklets):
        assert tracklet.identity == expected_id

    # Split all tracklets in half, which must leave six of them.
    halves = []
    for track in stitcher:
        halves.extend(stitcher.split_tracklet(track, [25]))
    stitcher = TrackletStitcher(halves, n_tracks=3)
    assert len(stitcher) == 6

    # Reference weight of one edge under the default weighting.
    edge = ('0out', '3in')
    stitcher.build_graph()
    weight = stitcher.G.edges[edge]['weight']

    def weight_func(t1, t2):
        # Distances between tracklets of equal identity are scaled by 0.01.
        if t1.identity == t2.identity:
            scale = 0.01
        else:
            scale = 1
        return scale * t1.distance_to(t2)

    stitcher.build_graph(weight_func=weight_func)
    assert stitcher.G.number_of_edges() == 27
    # With the custom weighting, the same edge must carry the default
    # weight floor-divided by 100.
    new_weight = stitcher.G.edges[edge]['weight']
    assert new_weight == weight // 100

    stitcher.stitch()
    assert len(stitcher.tracks) == 3
    for track in stitcher.tracks:
        assert len(track) == 50
    for track in stitcher.tracks:
        assert 0.998 <= track.likelihood <= 1
    tracks = sorted(stitcher.tracks, key=lambda t: t.identity)
    for expected_id, track in enumerate(tracks):
        assert track.identity == expected_id
예제 #2
0
def test_stitcher_montblanc(real_tracklets_montblanc):
    """Stitch the montblanc tracklets and compare against stored tracks.

    Checks the tracklet and residual counts, the maximum gap, the graph's
    edge count and non-zero weights, the stitched tracks, and finally the
    formatted DataFrame against the reference HDF5 file.
    """
    stitcher = TrackletStitcher.from_dict_of_dict(
        real_tracklets_montblanc, n_tracks=3
    )
    assert len(stitcher) == 5
    for tracklet in stitcher.tracklets:
        assert tracklet.is_continuous
    for tracklet in stitcher.tracklets:
        assert tracklet.identity == -1
    assert len(stitcher.residuals) == 1
    assert len(stitcher.residuals[0]) == 2
    assert stitcher.compute_max_gap(stitcher.tracklets) == 5

    stitcher.build_graph()
    assert stitcher.G.number_of_edges() == 18
    # Collect only the truthy edge weights (None and 0 are skipped).
    weights = []
    for *_, w in stitcher.G.edges.data("weight"):
        if w:
            weights.append(w)
    assert weights == [2453, 24498, 5428]

    stitcher.stitch()
    assert len(stitcher.tracks) == 3
    for track in stitcher.tracks:
        assert len(track) >= 176
    for track in stitcher.tracks:
        assert 0.996 <= track.likelihood <= 1

    # The path is relative, so it is resolved against the current working
    # directory at test time.
    df_gt = pd.read_hdf('tests/data/montblanc_tracks.h5')
    df = stitcher.format_df()
    actual = df.to_numpy()
    expected = df_gt.to_numpy()
    np.testing.assert_equal(actual, expected)
예제 #3
0
def test_stitcher_real(tmpdir_factory, real_tracklets):
    """Run TrackletStitcher end to end on the real tracklets fixture.

    Checks the initial tracklets, the graph built from them, the stitched
    tracks, and finally writes the tracks to a temporary HDF5 file.
    """
    stitcher = TrackletStitcher.from_dict_of_dict(real_tracklets, n_tracks=3)
    assert len(stitcher) == 3
    for tracklet in stitcher.tracklets:
        assert tracklet.is_continuous
    assert not stitcher.residuals
    assert stitcher.compute_max_gap() == 0

    stitcher.build_graph()
    assert stitcher.G.number_of_edges() == 9
    # Every edge is expected to carry no weight at all.
    for *_, weight in stitcher.G.edges.data("weight"):
        assert weight is None

    stitcher.stitch()
    assert len(stitcher.tracks) == 3
    for track in stitcher.tracks:
        assert len(track) == 50
    for track in stitcher.tracks:
        assert 0.998 <= track.likelihood <= 1

    # Only checks that writing completes; the output is not inspected.
    out_dir = tmpdir_factory.mktemp("data")
    output_name = out_dir.join("fake.h5")
    animal_names = ["mickey", "minnie", "bianca"]
    stitcher.write_tracks(output_name, animal_names)