def test_init(self, dim):
    """Constructor sanity checks: invalid argument combinations must raise or warn."""
    # no match_dims given at all -> constructor must reject
    with pytest.raises((ValueError, TypeError)):
        match_em.GreedyHungarianMatching()

    # all three distance thresholds at once is an unlikely combination -> warning
    with pytest.warns(UserWarning):
        match_em.GreedyHungarianMatching(match_dims=dim, dist_lat=1., dist_ax=1., dist_vol=1.)

    # match_dims without any distance threshold is also unlikely -> warning
    with pytest.warns(UserWarning):
        match_em.GreedyHungarianMatching(match_dims=dim)

    # match_dims outside the supported set (presumably {2, 3}) -> error
    with pytest.raises(ValueError):
        match_em.GreedyHungarianMatching(match_dims=1)
    with pytest.raises(ValueError):
        match_em.GreedyHungarianMatching(match_dims=4)
def __init__(self, *, raw_th, em_th, xy_unit: str, img_shape, ax_th=None, vol_th=None,
             lat_th=None, p_aggregation='pbinom_cdf', px_size=None, match_dims=2, diag=0,
             pphotxyzbg_mapping=None, num_workers=0, skip_th: "float | None" = None,
             return_format='batch-set', sanity_check=True):
    """
    Consistency post-processing of the network output.

    Args:
        raw_th: raw detection threshold applied to the probability channel
        em_th: emitter threshold (aggregated probability above which an emitter is accepted)
        xy_unit: unit of the xy coordinates, forwarded to the base class
        img_shape: shape of the frames; used for the per-emitter background extent
        ax_th: axial distance threshold for matching
        vol_th: volumetric distance threshold for matching / clustering (3D case)
        lat_th: lateral distance threshold for matching / clustering (2D case)
        p_aggregation: strategy for aggregating per-pixel probabilities
        px_size: pixel size, forwarded to the base class
        match_dims: dimensionality of the matching (2 or 3); selects lat_th vs vol_th
            as the clustering distance threshold
        diag: weight of the diagonal elements in the 3x3 neighbourhood kernel
        pphotxyzbg_mapping: channel mapping p/phot/x/y/z/bg in the network output;
            defaults to [0, 1, 2, 3, 4, -1]
        num_workers: number of worker processes
        skip_th: relative fraction of the detection output allowed to be active before
            post-processing is skipped. Useful during training when the network has not
            yet converged and major parts of the detection output are white (non-sparse).
        return_format: output format, forwarded to the base class
        sanity_check: run self.sanity_check() after construction
    """
    super().__init__(xy_unit=xy_unit, px_size=px_size, return_format=return_format)

    self.raw_th = raw_th
    self.em_th = em_th
    self.p_aggregation = p_aggregation
    self.match_dims = match_dims
    self.num_workers = num_workers
    self.skip_th = skip_th

    # Avoid the shared-mutable-default pitfall: build a fresh list per instance
    # instead of using a list literal as the parameter default.
    self.pphotxyzbg_mapping = [0, 1, 2, 3, 4, -1] if pphotxyzbg_mapping is None \
        else pphotxyzbg_mapping

    # bound method of a matcher configured with the same thresholds
    self._filter = match_emittersets.GreedyHungarianMatching(match_dims=match_dims,
                                                             dist_lat=lat_th,
                                                             dist_ax=ax_th,
                                                             dist_vol=vol_th).filter

    self._bg_calculator = decode.simulation.background.BgPerEmitterFromBgFrame(
        filter_size=13, xextent=(0., 1.), yextent=(0., 1.), img_shape=img_shape)

    # 3x3 neighbourhood kernel; `diag` weights the corner (diagonal) neighbours
    self._neighbor_kernel = torch.tensor([[diag, 1, diag],
                                          [1, 1, 1],
                                          [diag, 1, diag]]).float().view(1, 1, 3, 3)

    # clustering threshold depends on matching dimensionality: lateral for 2D, volumetric for 3D
    self._clusterer = AgglomerativeClustering(
        n_clusters=None,
        distance_threshold=lat_th if self.match_dims == 2 else vol_th,
        affinity='precomputed', linkage='single')

    if sanity_check:
        self.sanity_check()
def test_filter_kernel_hand(self):
    """Hand-crafted check of the pairwise filter with a 2D lateral threshold of 2."""
    matcher = match_em.GreedyHungarianMatching(match_dims=2, dist_lat=2., dist_ax=None,
                                               dist_vol=None)

    # outputs at the origin vs. targets at varying distances
    # (renamed from `filter` so the builtin is not shadowed)
    filter_mask = matcher.filter(
        torch.zeros((4, 3)),
        torch.tensor([[1.9, 0., 0.], [2.1, 0., 0.], [0., 0., -5000.], [1.5, 1.5, 0.]]))

    assert filter_mask[:, 0].all()        # lateral dist 1.9 < 2. -> allowed
    assert not filter_mask[:, 1].all()    # lateral dist 2.1 > 2. -> excluded
    assert filter_mask[:, 2].all()        # axial distance ignored for 2D matching
    assert not filter_mask[:, 3].all()    # sqrt(1.5^2 + 1.5^2) ~ 2.12 > 2. -> excluded
def test_filter_kernel_statistical(self, dist_lat, dist_ax, dist_vol):
    """Statistical check of the filter on uniformly random outputs and targets."""
    matcher = match_em.GreedyHungarianMatching(match_dims=2, dist_lat=dist_lat,
                                               dist_ax=dist_ax, dist_vol=dist_vol)

    n_batch, n_out, n_tar = 10, 1000, 1200
    # x/y drawn from [0, 500), z from [0, 1000); batched implementation
    scale = torch.tensor([500, 500, 1000]).view(1, 1, 3)
    xyz_out = torch.rand((n_batch, n_out, 3)) * scale
    xyz_tar = torch.rand((n_batch, n_tar, 3)) * scale

    act = matcher.filter(xyz_out, xyz_tar)  # active (allowed) pairs

    self.assert_dists(xyz_out, xyz_tar, dist_lat, dist_ax, dist_vol, act)
def test_match_kernel(self, match_dims, xyz_out, xyz_tar, expected):
    """Check the match kernel's true-positive indices against the expected assignment."""
    matcher = match_em.GreedyHungarianMatching(match_dims=match_dims, dist_lat=1.,
                                               dist_ax=2., dist_vol=None)

    # filter expects batched input; the kernel works on a single sample
    filter_mask = matcher.filter(xyz_out.unsqueeze(0), xyz_tar.unsqueeze(0))
    assignment = matcher._match_kernel(xyz_out, xyz_tar, filter_mask.squeeze(0))

    tp_ix_out, tp_match_ix_out = assignment[2:]
    tp_ix_exp, tp_match_ix_exp = expected

    # both outputs are boolean masks: compare the positions of the True entries
    assert (tp_ix_out.nonzero() == tp_ix_exp).all()
    assert (tp_match_ix_out.nonzero() == tp_match_ix_exp).all()
def matcher(self):
    """Fixture: a default 2D greedy Hungarian matcher."""
    return match_em.GreedyHungarianMatching(match_dims=2)