def fake_stitcher():
    """Build a two-track TrackletStitcher from one random tracklet cut into pieces.

    A single tracklet spanning ``TRACKLET_LEN`` indices is filled with random
    values of shape ``(TRACKLET_LEN, N_DETS, 3)`` and handed to
    ``TrackletStitcher.split_tracklet`` together with the interior split
    points; the result is wrapped in a stitcher with ``n_tracks=2``.
    """
    frame_inds = np.arange(TRACKLET_LEN)
    n_frames = frame_inds.size
    random_data = np.random.rand(n_frames, N_DETS, 3)
    whole_track = Tracklet(random_data, frame_inds)
    # N_TRACKLETS + 1 evenly spaced boundaries; the first and last are the
    # tracklet's own ends, so only the interior ones are passed on.
    boundaries = np.linspace(0, n_frames, N_TRACKLETS + 1, dtype=int)
    pieces = TrackletStitcher.split_tracklet(whole_track, boundaries[1:-1])
    return TrackletStitcher(pieces, n_tracks=2)
def test_purify_tracklets(fake_tracklet):
    """Check purify_tracklet on all-NaN data and on data with one filled row."""
    # With nothing but NaNs, purification is expected to return None.
    fake_tracklet.data = np.full_like(fake_tracklet.data, np.nan)
    purified = TrackletStitcher.purify_tracklet(fake_tracklet)
    assert purified is None

    # Once the first row is filled, exactly that entry should be returned.
    fake_tracklet.data[0] = 1
    purified = TrackletStitcher.purify_tracklet(fake_tracklet)
    assert len(purified) == 1
    assert purified.inds == fake_tracklet.inds[0]
def test_stitcher_wrong_inputs(fake_tracklet):
    """Check that invalid TrackletStitcher constructor arguments raise."""
    # (expected exception, tracklets argument, keyword arguments)
    invalid_calls = (
        (IOError, [], {"n_tracks": 2}),
        (ValueError, [fake_tracklet], {"n_tracks": 1}),
        (ValueError, [fake_tracklet], {"n_tracks": 2, "min_length": 2}),
    )
    for expected_error, tracklets, kwargs in invalid_calls:
        with pytest.raises(expected_error):
            TrackletStitcher(tracklets, **kwargs)
def test_purify_tracklets(tracklet):
    """Check purify_tracklet on all-NaN data and on data with one filled row."""
    # With nothing but NaNs, purification is expected to return None.
    tracklet.data = np.full_like(tracklet.data, np.nan)
    result = TrackletStitcher.purify_tracklet(tracklet)
    assert result is None

    # Once the first row is filled, exactly that entry should be returned.
    tracklet.data[0] = 1
    result = TrackletStitcher.purify_tracklet(tracklet)
    assert len(result) == 1
    assert result.inds == tracklet.inds[0]
def test_stitcher_with_identity(real_tracklets):
    """Check stitching when tracklets carry identities and a custom weight is used."""
    # Add fake IDs: write i into the last column of every array of tracklet i.
    for fake_id in range(3):
        for values in real_tracklets[fake_id].values():
            values[:, -1] = fake_id

    stitcher = TrackletStitcher.from_dict_of_dict(real_tracklets, n_tracks=3)
    by_identity = sorted(stitcher, key=lambda t: t.identity)
    for expected_id, tracklet in enumerate(by_identity):
        assert tracklet.identity == expected_id

    # Split all tracklets in half
    halves = []
    for track in stitcher:
        halves.extend(stitcher.split_tracklet(track, [25]))
    stitcher = TrackletStitcher(halves, n_tracks=3)
    assert len(stitcher) == 6

    # Reference weight of one edge under the default weighting.
    edge = ('0out', '3in')
    stitcher.build_graph()
    weight = stitcher.G.edges[edge]['weight']

    def weight_func(t1, t2):
        # Scale the distance down by 100 when both tracklets share an identity.
        scale = 0.01 if t1.identity == t2.identity else 1
        return scale * t1.distance_to(t2)

    stitcher.build_graph(weight_func=weight_func)
    assert stitcher.G.number_of_edges() == 27
    new_weight = stitcher.G.edges[edge]['weight']
    assert new_weight == weight // 100

    stitcher.stitch()
    assert len(stitcher.tracks) == 3
    assert all(len(track) == 50 for track in stitcher.tracks)
    assert all(0.998 <= track.likelihood <= 1 for track in stitcher.tracks)
    sorted_tracks = sorted(stitcher.tracks, key=lambda t: t.identity)
    for expected_id, track in enumerate(sorted_tracks):
        assert track.identity == expected_id
def test_stitcher_montblanc(real_tracklets_montblanc):
    """Run the stitcher on the montblanc tracklets and compare with stored tracks."""
    stitcher = TrackletStitcher.from_dict_of_dict(
        real_tracklets_montblanc, n_tracks=3
    )

    # State right after construction.
    assert len(stitcher) == 5
    assert all(t.is_continuous for t in stitcher.tracklets)
    assert all(t.identity == -1 for t in stitcher.tracklets)
    assert len(stitcher.residuals) == 1
    assert len(stitcher.residuals[0]) == 2
    assert stitcher.compute_max_gap(stitcher.tracklets) == 5

    # Graph structure; only truthy edge weights are compared.
    stitcher.build_graph()
    assert stitcher.G.number_of_edges() == 18
    truthy_weights = [
        weight for *_, weight in stitcher.G.edges.data("weight") if weight
    ]
    assert truthy_weights == [2453, 24498, 5428]

    # Stitched tracks.
    stitcher.stitch()
    assert len(stitcher.tracks) == 3
    assert all(len(track) >= 176 for track in stitcher.tracks)
    assert all(0.996 <= track.likelihood <= 1 for track in stitcher.tracks)

    # The formatted output must match the reference file exactly.
    expected_df = pd.read_hdf('tests/data/montblanc_tracks.h5')
    actual_df = stitcher.format_df()
    np.testing.assert_equal(actual_df.to_numpy(), expected_df.to_numpy())
def test_stitcher_real(tmpdir_factory, real_tracklets):
    """Run the stitcher on real tracklets and write the tracks to a temp file."""
    stitcher = TrackletStitcher.from_dict_of_dict(real_tracklets, n_tracks=3)

    # State right after construction.
    assert len(stitcher) == 3
    assert all(t.is_continuous for t in stitcher.tracklets)
    assert not stitcher.residuals
    # NOTE(review): called without arguments here, whereas the montblanc test
    # passes stitcher.tracklets -- confirm against the method's signature.
    assert stitcher.compute_max_gap() == 0

    # Every edge of the graph is expected to carry a None weight.
    stitcher.build_graph()
    assert stitcher.G.number_of_edges() == 9
    edge_weights = stitcher.G.edges.data("weight")
    assert all(weight is None for *_, weight in edge_weights)

    # Stitched tracks.
    stitcher.stitch()
    assert len(stitcher.tracks) == 3
    assert all(len(track) == 50 for track in stitcher.tracks)
    assert all(0.998 <= track.likelihood <= 1 for track in stitcher.tracks)

    # Write into a fresh temporary directory, passing three names.
    temp_dir = tmpdir_factory.mktemp("data")
    output_name = temp_dir.join("fake.h5")
    stitcher.write_tracks(output_name, ["mickey", "minnie", "bianca"])
def generate_train_triplets_from_pickle(path_to_track, n_triplets=1000):
    """Mine triplets from a pickled tracklet file.

    A ``TrackletStitcher`` is created with ``from_pickle(path_to_track, 3)``
    and asked to mine ``n_triplets`` triplets.

    Args:
        path_to_track: Path handed to ``TrackletStitcher.from_pickle``.
        n_triplets: Number of triplets requested from ``mine``.

    Returns:
        The object returned by ``mine``; its length is asserted to equal
        ``n_triplets``.
    """
    stitcher = TrackletStitcher.from_pickle(path_to_track, 3)
    mined = stitcher.mine(n_triplets)
    assert len(mined) == n_triplets
    return mined