def test_can_load_min_len(self):
    """Check that ``load_track_csvfile(min_len=...)`` filters short chains.

    The pairwise-linked ``TRACKS`` chains are saved once. They are then
    reloaded with two ``min_len`` cutoffs, and the surviving chains are
    compared against the expected coordinates. Each expected coordinate
    list is used for both ``line_x`` and ``line_y``.
    """
    trackfile = self.tempdir / 'test.csv'
    chains = tracking.link_all_chains(TRACKS, merge_fxn='pairwise', processes=1)
    tracking.save_track_csvfile(trackfile, chains)

    # (min_len, expected chains ordered by starting x coordinate)
    cases = [
        (3, [
            [1.0, 1.1, 1.2, 1.3, 1.4],
            [2.0, 2.1, 2.2],
            [3.0, 3.1, 3.2, 3.3, 3.4],
            [4.0, 4.1, 4.2, 4.3],
            [5.2, 5.3, 5.4],
        ]),
        (4, [
            [1.0, 1.1, 1.2, 1.3, 1.4],
            [3.0, 3.1, 3.2, 3.3, 3.4],
            [4.0, 4.1, 4.2, 4.3],
        ]),
    ]
    for min_len, exp_chains in cases:
        with self.subTest(min_len=min_len):
            load_chains = tracking.load_track_csvfile(trackfile, min_len=min_len)
            # Sort by starting x, as the other linking tests do, so the
            # comparison does not depend on the order chains were saved in.
            load_chains = sorted(load_chains, key=lambda c: c.line_x[0])
            self.assertEqual(len(load_chains), len(exp_chains))
            for lc, exp_coords in zip(load_chains, exp_chains):
                self.assertEqual(lc.line_x, exp_coords)
                self.assertEqual(lc.line_y, exp_coords)
def test_links_tracks_different_step(self):
    """Check pairwise linking of ``TRACKS`` with ``link_step=2``.

    The linked chains are ordered by their first x coordinate. Each one is
    compared against an expected coordinate list, which is used for both
    ``line_x`` and ``line_y``.
    """
    linked = tracking.link_all_chains(
        TRACKS,
        link_step=2,
        merge_fxn='pairwise',
        processes=1,
    )
    linked = sorted(linked, key=lambda chain: chain.line_x[0])

    expected = [
        [1.0, 1.2, 1.4],
        [1.1, 1.3],
        [2.0, 2.2, 2.4],
        [2.1],
        [3.0, 3.2, 3.4],
        [3.1, 3.3],
        [4.0, 4.2],
        [4.1, 4.3],
        [5.0, 5.2, 5.4],
        [5.3],
    ]

    self.assertEqual(len(linked), len(expected))
    for want, chain in zip(expected, linked):
        self.assertEqual(want, chain.line_x)
        self.assertEqual(want, chain.line_y)
def test_can_track_save_load(self):
    """Check that chains round-trip through a track CSV file.

    The pairwise-linked ``TRACKS`` chains are written with
    ``save_track_csvfile`` and read back with ``load_track_csvfile``. The
    reloaded chains must match the originals in count, order and equality.
    """
    csv_path = self.tempdir / 'test.csv'
    original = tracking.link_all_chains(TRACKS, merge_fxn='pairwise', processes=1)
    tracking.save_track_csvfile(csv_path, original)

    reloaded = tracking.load_track_csvfile(csv_path)

    self.assertEqual(len(reloaded), len(original))
    for got, want in zip(reloaded, original):
        self.assertEqual(got, want)
def test_links_default_tracks_cluster(self):
    """Check cluster-based linking of ``TRACKS`` with imputation enabled.

    Linking uses ``merge_fxn='cluster'``, ``max_merge_dist=4.0`` and
    ``impute_steps=3``. The resulting chains are ordered by their first x
    coordinate. Each chain's ``line_x`` and ``line_y`` are compared
    element-wise against the expected coordinates to 2 decimal places.
    """
    chains = tracking.link_all_chains(TRACKS, merge_fxn='cluster',
                                      max_merge_dist=4.0, impute_steps=3,
                                      processes=1)
    chains = sorted(chains, key=lambda c: c.line_x[0])
    chain_coords = [
        [1.0],
        [1.4],
        [2.0],
        [2.4],
        [3.0, 3.1, 4.2, 3.55, 3.4],
        [4.0],
        [5.0],
        [5.4],
    ]
    self.assertEqual(len(chains), len(chain_coords))

    fmt = 'Mismatched coordinates\n expected {}\n got {}\n\n'
    for coords, chain in zip(chain_coords, chains):
        # The x and y checks are identical, so run them in one loop.
        for line in (chain.line_x, chain.line_y):
            msg = fmt.format(coords, line)
            self.assertEqual(len(coords), len(line), msg=msg)
            # Expected values have at most 2 decimals, so comparing to 2
            # places is equivalent to the old rounding check. Unlike the old
            # all([...]) check, a failure reports the offending element.
            for want, got in zip(coords, line):
                self.assertAlmostEqual(want, got, places=2, msg=msg)
def test_gets_sane_results_for_velocity(self):
    """Check velocities of the pairwise-linked ``TRACKS`` chains.

    Chains are ordered by their first x coordinate. For each chain, the
    expected array is compared against both ``vel_x()`` and ``vel_y()``,
    and the expected array times sqrt(2) is compared against ``vel_mag()``.
    """
    linked = sorted(
        tracking.link_all_chains(TRACKS, merge_fxn='pairwise', processes=1),
        key=lambda chain: chain.line_x[0],
    )

    # (number of velocity samples, per-step velocity) for each chain.
    vel_specs = [
        (4, 0.1),
        (2, 0.1),
        (1, 0.0),
        (4, 0.1),
        (3, 0.1),
        (1, 0.0),
        (2, 0.1),
    ]
    expected_vels = [np.array([step] * count) for count, step in vel_specs]

    self.assertEqual(len(linked), len(expected_vels))
    diag = np.sqrt(2)
    for expected, chain in zip(expected_vels, linked):
        np.testing.assert_almost_equal(expected, chain.vel_x())
        np.testing.assert_almost_equal(expected, chain.vel_y())
        np.testing.assert_almost_equal(expected * diag, chain.vel_mag())