def test_stitcher_with_identity(real_tracklets):
    """Check that tracklet identities are exposed and can steer the stitching.

    The last column of every array in the first three tracklets is overwritten
    with a fake identity (0, 1 or 2). The test then checks that:

    * the stitcher reports those identities;
    * after splitting each tracklet at index 25, six tracklets are obtained;
    * a custom ``weight_func`` passed to ``build_graph`` replaces the default
      edge weights;
    * stitching yields three tracks of length 50, one per identity.
    """
    import copy

    # Work on a private copy. The assignment below modifies the fixture's
    # arrays in place, and the fixture object may be shared with other tests
    # depending on its scope (not visible here). Without the copy, the fake
    # identities would leak into those tests.
    real_tracklets = copy.deepcopy(real_tracklets)

    # Add fake IDs
    for i in range(3):
        tracklet = real_tracklets[i]
        for v in tracklet.values():
            v[:, -1] = i
    stitcher = TrackletStitcher.from_dict_of_dict(real_tracklets, n_tracks=3)
    tracklets = sorted(stitcher, key=lambda t: t.identity)
    assert all(tracklet.identity == i for i, tracklet in enumerate(tracklets))

    # Split all tracklets in half
    tracklets = [
        t for track in stitcher for t in stitcher.split_tracklet(track, [25])
    ]
    stitcher = TrackletStitcher(tracklets, n_tracks=3)
    assert len(stitcher) == 6

    # Reference weight of one edge under the default weighting.
    stitcher.build_graph()
    weight = stitcher.G.edges[('0out', '3in')]['weight']

    def weight_func(t1, t2):
        # Make links between same-identity tracklets 100x cheaper.
        w = 0.01 if t1.identity == t2.identity else 1
        return w * t1.distance_to(t2)

    stitcher.build_graph(weight_func=weight_func)
    assert stitcher.G.number_of_edges() == 27
    new_weight = stitcher.G.edges[('0out', '3in')]['weight']
    # NOTE(review): this equality presumes that the default weight is the
    # (integer) distance and that nodes 0 and 3 share an identity — confirm
    # against build_graph / split_tracklet.
    assert new_weight == weight // 100

    stitcher.stitch()
    assert len(stitcher.tracks) == 3
    assert all(len(track) == 50 for track in stitcher.tracks)
    assert all(0.998 <= track.likelihood <= 1 for track in stitcher.tracks)
    tracks = sorted(stitcher.tracks, key=lambda t: t.identity)
    assert all(track.identity == i for i, track in enumerate(tracks))
def test_stitcher_montblanc(real_tracklets_montblanc):
    """Regression test of the whole stitching pipeline on the montblanc data.

    Covers the tracklets and residuals extracted from the fixture, the graph
    built over them, and the stitched tracks. Finally, the formatted output is
    compared value-by-value against a stored ground-truth HDF5 file.
    """
    stitcher = TrackletStitcher.from_dict_of_dict(
        real_tracklets_montblanc,
        n_tracks=3,
    )

    # Five continuous, identity-less tracklets plus one residual of length 2.
    assert len(stitcher) == 5
    assert all(tl.is_continuous for tl in stitcher.tracklets)
    assert all(tl.identity == -1 for tl in stitcher.tracklets)
    residuals = stitcher.residuals
    assert len(residuals) == 1
    assert len(residuals[0]) == 2
    max_gap = stitcher.compute_max_gap(stitcher.tracklets)
    assert max_gap == 5

    # Graph construction: only truthy weights are compared.
    stitcher.build_graph()
    graph = stitcher.G
    assert graph.number_of_edges() == 18
    nonzero_weights = [
        edge_weight
        for *_, edge_weight in graph.edges.data("weight")
        if edge_weight
    ]
    assert nonzero_weights == [2453, 24498, 5428]

    # Stitching result.
    stitcher.stitch()
    assert len(stitcher.tracks) == 3
    assert all(len(trk) >= 176 for trk in stitcher.tracks)
    assert all(0.996 <= trk.likelihood <= 1 for trk in stitcher.tracks)

    # NOTE(review): this path is relative to the current working directory,
    # so the test presumably has to be run from the repository root.
    expected = pd.read_hdf('tests/data/montblanc_tracks.h5')
    actual = stitcher.format_df()
    np.testing.assert_equal(actual.to_numpy(), expected.to_numpy())
def test_stitcher_real(tmpdir_factory, real_tracklets):
    """Stitch three complete tracklets into three tracks and write them out.

    The input has no residuals and no gaps, and every edge built between the
    tracklets carries a ``None`` weight. Each resulting track has length 50
    and a likelihood within [0.998, 1]. The tracks are then written to a
    temporary ``fake.h5`` file.
    """
    n_tracks = 3
    stitcher = TrackletStitcher.from_dict_of_dict(
        real_tracklets, n_tracks=n_tracks
    )

    # Sanity of the extracted tracklets.
    assert len(stitcher) == n_tracks
    assert all(piece.is_continuous for piece in stitcher.tracklets)
    assert not stitcher.residuals
    assert stitcher.compute_max_gap() == 0

    # Graph: 9 edges, none of which has a weight.
    stitcher.build_graph()
    graph = stitcher.G
    assert graph.number_of_edges() == 9
    assert all(w is None for *_, w in graph.edges.data("weight"))

    # Stitching result.
    stitcher.stitch()
    assert len(stitcher.tracks) == n_tracks
    assert all(len(trk) == 50 for trk in stitcher.tracks)
    assert all(0.998 <= trk.likelihood <= 1 for trk in stitcher.tracks)

    # Output to a fresh temporary directory.
    out_dir = tmpdir_factory.mktemp("data")
    output_file = out_dir.join("fake.h5")
    track_names = ["mickey", "minnie", "bianca"]
    stitcher.write_tracks(output_file, track_names)